slope 0.29.0
Loading...
Searching...
No Matches
folds.cpp
#include "folds.h"
#include <algorithm>
#include <numeric>
#include <random>
#include <stdexcept>
4
5namespace slope {
6
7const std::vector<int>&
8Folds::getTestIndices(size_t fold_idx, size_t rep_idx) const
9{
10 return folds[rep_idx][fold_idx];
11}
12
13std::vector<int>
14Folds::getTrainingIndices(size_t fold_idx, size_t rep_idx) const
15{
16 std::vector<int> train_indices;
17 for (size_t i = 0; i < n_folds; ++i) {
18 if (i != fold_idx) {
19 const auto& fold = folds[rep_idx][i];
20 train_indices.insert(train_indices.end(), fold.begin(), fold.end());
21 }
22 }
23 return train_indices;
24}
25
26std::vector<std::vector<int>>
27Folds::createFolds(int n, int n_folds, uint64_t random_seed)
28{
29 // Initialize random number generator
30 std::mt19937 generator(random_seed);
31
32 // Create and shuffle indices
33 std::vector<int> indices(n);
34 std::iota(indices.begin(), indices.end(), 0);
35 std::shuffle(indices.begin(), indices.end(), generator);
36
37 // Create folds
38 std::vector<std::vector<int>> folds(n_folds);
39
40 // Calculate base fold size and remainder
41 int base_fold_size = n / n_folds;
42 int remainder = n % n_folds;
43
44 // Current position in indices
45 int current_pos = 0;
46
47 // Distribute indices across folds
48 for (int fold = 0; fold < n_folds; ++fold) {
49 // Add one extra element to early folds if we have remainder
50 int fold_size = base_fold_size + (fold < remainder ? 1 : 0);
51
52 // Fill this fold
53 folds[fold].reserve(fold_size);
54 for (int i = 0; i < fold_size; ++i) {
55 folds[fold].push_back(indices[current_pos++]);
56 }
57 }
58
59 return folds;
60}
61
62} // namespace slope
std::size_t n_folds
Number of folds.
Definition folds.h:164
std::vector< int > getTrainingIndices(size_t fold_idx, size_t rep_idx=0) const
Get training indices for a specific fold and repetition.
Definition folds.cpp:14
const std::vector< int > & getTestIndices(size_t fold_idx, size_t rep_idx=0) const
Get test indices for a specific fold and repetition.
Definition folds.cpp:8
std::vector< std::vector< std::vector< int > > > folds
Indices for each fold in each repetition.
Definition folds.h:163
Cross-validation fold management for SLOPE models.
Namespace containing SLOPE regression implementation.
Definition clusters.cpp:5