Predictions for sgdnet Models

# S3 method for sgdnet
predict(object, newx = NULL, s = NULL, type,
  exact = FALSE, newoffset = NULL, ...)

# S3 method for sgdnet_gaussian
predict(object, newx = NULL, s = NULL,
  type = c("link", "response", "coefficients", "nonzero"),
  exact = FALSE, newoffset = NULL, ...)

# S3 method for sgdnet_binomial
predict(object, newx = NULL, s = NULL,
  type = c("link", "response", "coefficients", "nonzero", "class"),
  exact = FALSE, newoffset = NULL, ...)

# S3 method for sgdnet_multinomial
predict(object, newx = NULL, s = NULL,
  type = c("link", "response", "coefficients", "nonzero", "class"),
  exact = FALSE, newoffset = NULL, ...)

# S3 method for sgdnet_mgaussian
predict(object, newx = NULL, s = NULL,
  type = c("link", "response", "coefficients", "nonzero"),
  exact = FALSE, newoffset = NULL, ...)

Arguments

object

an object of class 'sgdnet'.

newx

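new data to predict on. Must be provided for every type that makes predictions from data ("link", "response", or "class"); it is not needed when type is "coefficients" or "nonzero".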

s

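the lambda penalty value(s) on which to base the predictions. If NULL (the default), predictions are returned for every lambda along the regularization path of the fitted model.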

type

type of prediction to return, one of

link

linear predictors,

response

responses,

coefficients

coefficients (weights); equivalent to calling coef()

nonzero

nonzero coefficients at each step of the regularization path, and

class

class predictions for each new data point in newx at each step of the regularization path -- only useful for 'binomial' and 'multinomial' families.

exact

if the given value of s is not among the lambda values in the fitted model and exact = TRUE, the model will be refit using s. If exact = FALSE, predictions will instead be made using a linearly interpolated coefficient matrix.

newoffset

if an offset was used in the call to sgdnet(), a new offset can be provided here for making predictions (but not for type = "coefficients" or type = "nonzero").

...

arguments to be passed on to stats::update() to refit the model via sgdnet() when exact = TRUE and s is not among the lambda values in the fitted model.

Value

Predictions for object given data in newx.

See also

Examples

# Gaussian
# Split into training and test sets
n <- length(abalone$y)
train_ind <- sample(n, size = floor(0.8 * n))

# Fit the model using the training set
fit_gaussian <- sgdnet(abalone$x[train_ind, ], abalone$y[train_ind])

# Predict using the test set
pred_gaussian <- predict(fit_gaussian, newx = abalone$x[-train_ind, ])

# Mean absolute prediction error along regularization path
mae <- 1/(n - length(train_ind)) *
  colSums(abs(abalone$y[-train_ind] - pred_gaussian))

# Binomial
n <- length(heart$y)
train_ind <- sample(n, size = floor(0.8 * n))
fit_binomial <- sgdnet(heart$x[train_ind, ], heart$y[train_ind],
                       family = "binomial")

# Predict classes at custom lambda value (s) using linear interpolation
predict(fit_binomial, heart$x[-train_ind, ], type = "class", s = 1/n)
#> 1 #> [1,] "absence" #> [2,] "absence" #> [3,] "absence" #> [4,] "presence" #> [5,] "presence" #> [6,] "absence" #> [7,] "absence" #> [8,] "presence" #> [9,] "absence" #> [10,] "absence" #> [11,] "absence" #> [12,] "absence" #> [13,] "absence" #> [14,] "presence" #> [15,] "absence" #> [16,] "presence" #> [17,] "absence" #> [18,] "presence" #> [19,] "absence" #> [20,] "presence" #> [21,] "absence" #> [22,] "absence" #> [23,] "absence" #> [24,] "presence" #> [25,] "absence" #> [26,] "presence" #> [27,] "presence" #> [28,] "absence" #> [29,] "presence" #> [30,] "presence" #> [31,] "presence" #> [32,] "absence" #> [33,] "absence" #> [34,] "absence" #> [35,] "presence" #> [36,] "absence" #> [37,] "presence" #> [38,] "absence" #> [39,] "absence" #> [40,] "absence" #> [41,] "absence" #> [42,] "absence" #> [43,] "absence" #> [44,] "absence" #> [45,] "presence" #> [46,] "absence" #> [47,] "absence" #> [48,] "presence" #> [49,] "absence" #> [50,] "absence" #> [51,] "presence" #> [52,] "absence" #> [53,] "absence" #> [54,] "absence"
# Multinomial
n <- length(wine$y)
train_ind <- sample(n, size = floor(0.8 * n))
fit_multinomial <- sgdnet(wine$x[train_ind, ], wine$y[train_ind],
                          family = "multinomial", alpha = 0.25)
predict(fit_multinomial, wine$x[-train_ind, ], s = 0.0001, exact = TRUE,
        type = "class")
#> Error in NROW(x): object 'train_ind' not found
# Multivariate gaussian regression, predict nonzero coefficients
fit_mgaussian <- sgdnet(student$x, student$y, family = "mgaussian")
predict(fit_mgaussian, type = "nonzero")
#> $s0 #> NULL #> #> $s1 #> [1] 7 #> #> $s2 #> [1] 7 #> #> $s3 #> [1] 3 7 #> #> $s4 #> [1] 2 3 7 #> #> $s5 #> [1] 2 3 7 #> #> $s6 #> [1] 2 3 7 #> #> $s7 #> [1] 2 3 4 7 #> #> $s8 #> [1] 2 3 4 7 19 #> #> $s9 #> [1] 2 3 4 5 7 10 19 #> #> $s10 #> [1] 2 3 4 5 7 10 11 19 #> #> $s11 #> [1] 1 2 3 4 5 7 8 10 11 18 19 #> #> $s12 #> [1] 1 2 3 4 5 7 8 9 10 11 18 19 #> #> $s13 #> [1] 1 2 3 4 5 7 8 9 10 11 18 19 #> #> $s14 #> [1] 1 2 3 4 5 7 8 9 10 11 18 19 #> #> $s15 #> [1] 1 2 3 4 5 7 8 9 10 11 18 19 #> #> $s16 #> [1] 1 2 3 4 5 7 8 9 10 11 18 19 #> #> $s17 #> [1] 1 2 3 4 5 7 8 9 10 11 18 19 #> #> $s18 #> [1] 1 2 3 4 5 7 8 9 10 11 12 16 18 19 #> #> $s19 #> [1] 1 2 3 4 5 7 8 9 10 11 12 16 18 19 #> #> $s20 #> [1] 1 2 3 4 5 7 8 9 10 11 12 16 18 19 21 #> #> $s21 #> [1] 1 2 3 4 5 7 8 9 10 11 12 16 18 19 21 #> #> $s22 #> [1] 1 2 3 4 5 7 8 9 10 11 12 16 18 19 21 #> #> $s23 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 16 17 18 19 21 #> #> $s24 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 16 17 18 19 21 #> #> $s25 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 16 17 18 19 21 #> #> $s26 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 16 17 18 19 21 #> #> $s27 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s28 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s29 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s30 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s31 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s32 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s33 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s34 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 15 16 17 18 19 21 #> #> $s35 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 21 #> #> $s36 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 21 #> #> $s37 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 21 #> #> $s38 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 21 #> #> $s39 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 21 #> #> $s40 #> [1] 1 2 3 4 5 6 7 
8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s41 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s42 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s43 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s44 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s45 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s46 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s47 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s48 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s49 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s50 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s51 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s52 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s53 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s54 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s55 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s56 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s57 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s58 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s59 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s60 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s61 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s62 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s63 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s64 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s65 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s66 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s67 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s68 #> 
[1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s69 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s70 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s71 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s72 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s73 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s74 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s75 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s76 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s77 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s78 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s79 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s80 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s81 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s82 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s83 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s84 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s85 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s86 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s87 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s88 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s89 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s90 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s91 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s92 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s93 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s94 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s95 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 
21 #> #> $s96 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s97 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s98 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #> #> $s99 #> [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 #>