slope 0.29.0
Loading...
Searching...
No Matches
ols.cpp
1#include "ols.h"
2#include <Eigen/Dense>
3#include <Eigen/SparseQR>
4#include <utility>
5
6std::pair<double, Eigen::VectorXd>
7fitOls(const Eigen::MatrixXd& X, const Eigen::VectorXd& y, bool fit_intercept)
8{
9 Eigen::MatrixXd x_mod = X;
10
11 if (fit_intercept) {
12 // Add column of ones for intercept
13 Eigen::MatrixXd ones = Eigen::MatrixXd::Ones(x_mod.rows(), 1);
14 x_mod.conservativeResize(Eigen::NoChange, x_mod.cols() + 1);
15 x_mod.rightCols(1) = ones;
16 }
17
18 // Solve with column pivoting
19 Eigen::ColPivHouseholderQR<Eigen::MatrixXd> qr(x_mod);
20 Eigen::VectorXd all_coefs = qr.solve(y);
21
22 double intercept = 0.0;
23 Eigen::VectorXd coeffs;
24
25 if (fit_intercept) {
26 intercept = all_coefs(all_coefs.size() - 1);
27 coeffs = all_coefs.head(all_coefs.size() - 1);
28 } else {
29 coeffs = all_coefs;
30 }
31
32 return { intercept, coeffs };
33}
34
35std::pair<double, Eigen::VectorXd>
36fitOls(const Eigen::SparseMatrix<double>& X,
37 const Eigen::VectorXd& y,
38 bool fit_intercept)
39{
40 Eigen::SparseMatrix<double> x_mod = X;
41
42 if (fit_intercept) {
43 // Construct column of ones for intercept
44 Eigen::VectorXd ones = Eigen::VectorXd::Ones(x_mod.rows());
45 Eigen::SparseMatrix<double> intercept_col(x_mod.rows(), 1);
46 intercept_col.reserve(x_mod.rows());
47
48 for (int i = 0; i < x_mod.rows(); ++i) {
49 intercept_col.insert(i, 0) = ones(i);
50 }
51
52 // Concatenate X and intercept column
53 Eigen::SparseMatrix<double> temp(x_mod.rows(), x_mod.cols() + 1);
54 temp.leftCols(x_mod.cols()) = x_mod;
55 temp.rightCols(1) = intercept_col;
56 x_mod = temp;
57 }
58
59 // Solve with sparse QR
60 Eigen::SparseQR<Eigen::SparseMatrix<double>, Eigen::COLAMDOrdering<int>>
61 solver;
62 solver.compute(x_mod);
63 Eigen::VectorXd all_cofs = solver.solve(y);
64
65 double intercept = 0.0;
66 Eigen::VectorXd coeffs;
67
68 if (fit_intercept) {
69 intercept = all_cofs(all_cofs.size() - 1);
70 coeffs = all_cofs.head(all_cofs.size() - 1);
71 } else {
72 coeffs = all_cofs;
73 }
74
75 return { intercept, coeffs };
76}
Ordinary Least Squares (OLS) regression functionality.