slope 6.2.1
Loading...
Searching...
No Matches
utils.h
Go to the documentation of this file.
1
6#pragma once
7
#include "eigen_compat.h"
#include <Eigen/Core>
#include <Eigen/SparseCore>
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <set>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <vector>
17
18namespace slope {
19
20using slope::all;
21
31template<typename Derived>
32Eigen::Index
33nonZeros(const Eigen::SparseMatrixBase<Derived>& x)
34{
35 return x.derived().nonZeros();
36}
37
44template<typename Derived>
45Eigen::Index
46nonZeros(const Eigen::DenseBase<Derived>& x)
47{
48 return (x.derived().array() != 0).count();
49}
50
/**
 * @brief Sort the elements of a contiguous container in place.
 *
 * @tparam T Container exposing `data()`, `size()` and `value_type`
 *   (e.g. `std::vector` or an Eigen vector).
 * @param v Container to sort.
 * @param descending If true sort with `std::greater`, otherwise with
 *   `std::less`.
 */
template<typename T>
void
sort(T& v, const bool descending = false)
{
  using Value = typename T::value_type;

  auto* const first = v.data();
  auto* const last = first + v.size();

  if (!descending) {
    std::sort(first, last, std::less<Value>());
    return;
  }

  std::sort(first, last, std::greater<Value>());
}
79
/**
 * @brief Positions of the truthy elements of an indexable container.
 *
 * @tparam T Container supporting `size()` and `operator[]` whose elements are
 *   contextually convertible to bool.
 * @param x Container to scan.
 * @return Ascending list of indices `i` for which `x[i]` is true.
 */
template<typename T>
std::vector<int>
which(const T& x)
{
  std::vector<int> positions;

  const auto n = x.size();
  for (int idx = 0; idx < n; ++idx) {
    if (!x[idx]) {
      continue;
    }
    positions.push_back(idx);
  }

  return positions;
}
105
/**
 * @brief Compute the permutation that sorts `v`.
 *
 * Returns indices `idx` such that `v[idx[0]], v[idx[1]], ...` is in ascending
 * order, or descending order when `descending` is true. `v` itself is not
 * modified. The relative order of equal elements is unspecified
 * (`std::sort` is not stable).
 *
 * @tparam T Indexable container with `size()` and `operator[]`.
 * @param v Values to order; may be an lvalue or a temporary.
 * @param descending Order from largest to smallest if true.
 * @return A permutation of `0, ..., v.size() - 1`.
 */
template<typename T>
std::vector<int>
sortIndex(const T& v, const bool descending = false)
{
  std::vector<int> idx(v.size());
  std::iota(idx.begin(), idx.end(), 0);

  // Qualify std::sort explicitly: inside namespace slope an unqualified
  // `sort` would first find slope::sort and only reach std::sort via ADL on
  // the iterator type.
  if (descending) {
    std::sort(
      idx.begin(), idx.end(), [&v](int i, int j) { return v[i] > v[j]; });
  } else {
    std::sort(
      idx.begin(), idx.end(), [&v](int i, int j) { return v[i] < v[j]; });
  }

  return idx;
}
138
/**
 * @brief Reorder `values` in place so that new `values[i]` is old
 *   `values[ind[i]]`.
 *
 * Elements are moved, not copied. `ind` is expected to be a permutation of
 * `0, ..., values.size() - 1`. Debug builds assert that `ind` is long enough
 * and that each entry is in range; duplicates are not detected.
 *
 * @tparam T Indexable container constructible from a size (e.g.
 *   `std::vector` or an Eigen vector).
 * @param values Container to reorder.
 * @param ind Source position for each output slot.
 */
template<typename T>
void
permute(T& values, const std::vector<int>& ind)
{
  using Size = decltype(values.size());
  const Size n = values.size();

  // Output slot i reads ind[i], so `ind` needs at least n entries.
  assert(ind.size() >= static_cast<std::size_t>(n));

  // Build into a fresh container: permuting in place would overwrite entries
  // that later slots still have to read.
  T out(n);
  for (Size i = 0; i < n; ++i) {
    const int src = ind[static_cast<std::size_t>(i)];
    assert(src >= 0 && static_cast<Size>(src) < n);
    out[i] = std::move(values[src]);
  }

  values = std::move(out);
}
170
/**
 * @brief Undo a permutation in place: old `values[i]` moves to `ind[i]`.
 *
 * This is the inverse of reordering by `ind`. Elements are moved, not
 * copied. `ind` is expected to be a permutation of
 * `0, ..., values.size() - 1`. Debug builds assert that `ind` is long enough
 * and that each destination is in range; duplicates are not detected.
 *
 * @tparam T Indexable container constructible from a size.
 * @param values Container to reorder.
 * @param ind Destination position for each current element.
 */
template<typename T>
void
inversePermute(T& values, const std::vector<int>& ind)
{
  using Size = decltype(values.size());
  const Size n = values.size();

  // Source position i is scattered to ind[i], so `ind` needs n entries.
  assert(ind.size() >= static_cast<std::size_t>(n));

  T out(n);
  for (Size i = 0; i < n; ++i) {
    const int dest = ind[static_cast<std::size_t>(i)];
    assert(dest >= 0 && static_cast<Size>(dest) < n);
    out[dest] = std::move(values[i]);
  }

  values = std::move(out);
}
195
/**
 * @brief Move a contiguous run of elements to another position in a vector.
 *
 * The `size` elements starting at `from` are relocated so that they start at
 * `to`. The elements in between shift to fill the gap, and relative order
 * inside the run is kept. Preconditions are checked with `assert`.
 *
 * @tparam T Element type.
 * @param v Vector to modify.
 * @param from Start index of the run to move.
 * @param to Start index the run should occupy afterwards.
 * @param size Length of the run.
 */
template<typename T>
void
move_elements(std::vector<T>& v, const int from, const int to, const int size)
{
  assert(from >= 0);
  assert(to >= 0);
  assert(size >= 0);
  assert(from != to);

  const bool moving_left = to < from;
  if (moving_left) {
    assert(from + size <= static_cast<int>(v.size()));
  } else {
    assert(to + size <= static_cast<int>(v.size()));
  }

  // A single left-rotation of [first, last) around `middle` places the run
  // at its target position.
  const auto base = v.begin();
  const auto first = base + (moving_left ? to : from);
  const auto middle = base + from + (moving_left ? 0 : size);
  const auto last = base + (moving_left ? from : to) + size;
  std::rotate(first, middle, last);
}
226
/**
 * @brief Validates if a given value exists in a set of valid options.
 *
 * Only the declaration is visible here. What happens when `value` is not
 * one of `valid_options` (for example, whether an exception is thrown)
 * should be confirmed against the definition.
 *
 * @param value Option value supplied by the caller.
 * @param valid_options Set of accepted values.
 * @param parameter_name Name of the parameter being checked (presumably used
 *   in diagnostics; confirm against the definition).
 */
void
validateOption(const std::string& value,
               const std::set<std::string>& valid_options,
               const std::string& parameter_name);
243
255template<typename T>
256T
257subset(const Eigen::EigenBase<T>& x, const std::vector<int>& indices)
258{
259 return subset(x.derived(), indices);
260}
261
273template<typename T>
274typename Eigen::MatrixBase<T>::PlainObject
275subset(const Eigen::DenseBase<T>& x, const std::vector<int>& indices)
276{
277 return x.derived()(indices, all);
278}
279
293template<typename T>
294T
295subset(const Eigen::SparseMatrixBase<T>& x, const std::vector<int>& indices)
296{
297 std::vector<Eigen::Triplet<double>> triplets;
298 triplets.reserve(slope::nonZeros(x.derived()));
299
300 for (int j = 0; j < x.cols(); ++j) {
301 for (typename T::InnerIterator it(x.derived(), j); it; ++it) {
302 auto it_idx = std::find(indices.begin(), indices.end(), it.row());
303
304 if (it_idx != indices.end()) {
305 int new_row = std::distance(indices.begin(), it_idx);
306 triplets.emplace_back(new_row, j, it.value());
307 }
308 }
309 }
310
311 T out(indices.size(), x.cols());
312 out.setFromTriplets(triplets.begin(), triplets.end());
313
314 return out;
315}
316
329template<typename T>
330T
331subsetCols(const Eigen::MatrixBase<T>& x, const std::vector<int>& indices)
332{
333 return x.derived()(all, indices);
334}
335
348template<typename T>
349T
350subsetCols(const Eigen::SparseMatrixBase<T>& x, const std::vector<int>& indices)
351{
352 std::vector<Eigen::Triplet<double>> triplets;
353 triplets.reserve(slope::nonZeros(x.derived()));
354
355 for (size_t j_idx = 0; j_idx < indices.size(); ++j_idx) {
356 int j = indices[j_idx];
357 for (typename T::InnerIterator it(x.derived(), j); it; ++it) {
358 triplets.emplace_back(it.row(), j_idx, it.value());
359 }
360 }
361
362 T out(x.rows(), indices.size());
363 out.setFromTriplets(triplets.begin(), triplets.end());
364
365 return out;
366}
367
379inline std::unordered_set<double>
380unique(const Eigen::MatrixXd& x)
381{
382 std::unordered_set<double> unique;
383 for (Eigen::Index j = 0; j < x.cols(); j++) {
384 for (Eigen::Index i = 0; i < x.rows(); i++) {
385 unique.insert(x(i, j));
386 }
387 }
388
389 return unique;
390}
391
392} // namespace slope
Eigen compatibility layer for version differences.
Namespace containing SLOPE regression implementation.
Definition clusters.h:11
std::unordered_set< double > unique(const Eigen::MatrixXd &x)
Create a set of unique values from an Eigen matrix.
Definition utils.h:380
void move_elements(std::vector< T > &v, const int from, const int to, const int size)
Definition utils.h:210
std::vector< int > which(const T &x)
Definition utils.h:94
T subsetCols(const Eigen::MatrixBase< T > &x, const std::vector< int > &indices)
Extract specified columns from a dense matrix.
Definition utils.h:331
T subset(const Eigen::EigenBase< T > &x, const std::vector< int > &indices)
Extract a subset of rows from an Eigen matrix.
Definition utils.h:257
void validateOption(const std::string &value, const std::set< std::string > &valid_options, const std::string &parameter_name)
Validates if a given value exists in a set of valid options.
std::vector< int > sortIndex(T &v, const bool descending=false)
Definition utils.h:123
void permute(T &values, const std::vector< int > &ind)
Definition utils.h:152
void sort(T &v, const bool descending=false)
Definition utils.h:69
void inversePermute(T &values, const std::vector< int > &ind)
Definition utils.h:186
Eigen::Index nonZeros(const Eigen::SparseMatrixBase< Derived > &x)
Definition utils.h:33