# CS-7863-Sci-Stat-Proj-6/Schrick-Noah_Ridge-LASSO-Regression.R
#
# 126 lines
# 4.6 KiB
# R

if (!require("numbers")) install.packages("numbers")
library(numbers)
# Penalized logistic-regression loss (elastic-net form).
#
# Arguments:
#   X     predictor matrix or data frame, one row per observation; a column
#         of ones named "intercept" is prepended internally.
#   y     response, needs to be coded 0/1 (converted with as.numeric()).
#   beta  regression coefficients, length ncol(X) + 1, intercept first.
#   lam   penalty weight; lam = 0 gives un-penalized logistic regression.
#   alpha penalty mix: alpha = 0 ridge penalty, alpha = 1 lasso penalty.
#
# Returns a single numeric value:
#   mean negative log-likelihood +
#     lam * ((1 - alpha) * sum(beta^2) / 2 + alpha * sum(abs(beta)))
# The intercept coefficient is part of beta and is therefore penalized too.
penalized_loss <- function(X, y, beta, lam, alpha = 0) {
  m <- nrow(X)
  Xtilde <- as.matrix(cbind(intercept = rep(1, m), X))
  z <- as.vector(Xtilde %*% beta) # linear predictor, one value per row
  yclass <- as.numeric(y)

  # 1. logistic un-penalized loss.
  # log(1 + exp(z)) - y * z is algebraically the same as
  # -y * log(yhat) - (1 - y) * log(1 - yhat) with yhat = 1 / (1 + exp(-z)),
  # but the pmax/log1p form stays finite when the sigmoid saturates
  # (the direct form produces 0 * -Inf = NaN there).
  nll <- sum(pmax(z, 0) + log1p(exp(-abs(z))) - yclass * z) / m

  # 2. penalty; lam enters exactly once, so lam = 0 removes the penalty and
  # the ridge part scales linearly (not quadratically) with lam.
  penalty <- lam * ((1 - alpha) * sum(beta * beta) / 2 + # ridge
    alpha * sum(abs(beta))) # lasso

  nll + penalty
}
# Gradient, with respect to beta, of the logistic loss plus a ridge term.
#
# Arguments:
#   X    predictor matrix or data frame; a column of ones named "intercept"
#        is prepended internally.
#   y    response, needs to be coded 0/1 (converted with as.numeric()).
#   beta coefficients from the previous descent step, length ncol(X) + 1.
#   lam  ridge weight; lam = 0 gives the un-penalized logistic gradient.
#
# Returns an unnamed numeric vector of length ncol(X) + 1. Component a is
#   (sum over rows of the data term for column a + lam * beta[a]) / m,
# i.e. the ridge contribution is divided by m together with the data term.
ridge_grad <- function(X, y, beta, lam){
  n_obs <- nrow(X)
  n_coef <- ncol(X) + 1
  design <- as.matrix(cbind(intercept=rep(1, n_obs), X))
  prob <- 1/(1 + exp(-(design %*% beta))) # fitted probabilities
  target <- as.numeric(y)

  # per-row weights of the two loss terms; multiplying a length-m vector by
  # the design matrix scales every column by that vector
  w_pos <- as.vector(-target*(1 - prob))
  w_neg <- as.vector((1 - target)*prob)
  data_term <- colSums(w_pos*design + w_neg*design)

  # as.vector() drops the column names picked up from the design matrix
  as.vector(data_term + lam*beta[seq_len(n_coef)])/n_obs
}
### optimize beta's by minimizing penalized_loss with optim()
#
# Arguments:
#   X, y      data passed through to penalized_loss / ridge_grad.
#   beta_init starting coefficients; NULL means rep(.1, ncol(X) + 1).
#   lam       penalty weight passed to penalized_loss.
#   alpha     penalty mix passed to penalized_loss (0 = ridge, 1 = lasso).
#   method    optim() method, e.g. "BFGS", "CG", "Nelder-Mead".
#
# Returns list(loss = <objective value at the optimum>, betas = <optimum>).
ridge_betas <- function(X, y, beta_init = NULL, lam, alpha = 0, method = "BFGS") {
  if (is.null(beta_init)) {
    beta_init <- rep(.1, ncol(X) + 1)
  }

  # alpha is forwarded so the objective matches what the caller asked for
  objective <- function(beta) {
    penalized_loss(X, y, beta, lam, alpha = alpha)
  }

  # ridge_grad takes no alpha and has no L1 term, so it is only supplied when
  # the objective has no L1 term either (alpha == 0 or lam == 0). Otherwise
  # gr = NULL lets optim() fall back to its finite-difference gradient.
  gradient <- if (alpha == 0 || lam == 0) {
    function(beta) ridge_grad(X, y, beta, lam)
  } else {
    NULL
  }

  fit <- optim(beta_init, # guess
    fn = objective,
    gr = gradient,
    method = method
  )
  list(loss = fit$value, betas = fit$par)
}
# Coefficient fit used by the LASSO routine.
# Delegates to ridge_betas() with lam fixed at 0, alpha = 1 and the BFGS
# method, and returns that result unchanged (a list with loss and betas).
#   X, y      data forwarded to ridge_betas().
#   beta_init optional starting coefficients, forwarded as-is.
lasso_betas <- function(X,y,beta_init=NULL){
  fit <- ridge_betas(
    X, y,
    beta_init = beta_init,
    lam = 0,
    alpha = 1,
    method = "BFGS"
  )
  return(fit)
}
# Iteratively adjust betas: fit, soft-threshold by lambda, refit.
#
# Each pass (1) runs BFGS on penalized_loss with lam = 0 starting from the
# current betas, (2) shrinks every coefficient toward zero by lambda, setting
# those with |beta| <= lambda to exactly 0, and (3) refits with lasso_betas()
# starting from the shrunken values. The loop stops when the loss reported by
# two consecutive refits differs by less than tol.
#
# Arguments:
#   X, y     data forwarded to lasso_betas / penalized_loss / ridge_grad.
#   lambda   non-negative soft-threshold amount.
#   tol      convergence tolerance on the change in loss between passes.
#   max_iter upper bound on the number of passes (guards against a loop that
#            never meets tol).
#
# Returns the result of the last lasso_betas() call (list with loss, betas).
# NOTE(review): the returned betas come from the refit after thresholding, so
# they are not themselves the thresholded values - confirm this is intended.
# Side effects: prints progress every 25 passes and the final pass count.
lasso_coeff <- function(X, y, lambda = 0.03125, tol = 1e-2, max_iter = 1000) {
  stopifnot(is.numeric(lambda), length(lambda) == 1, lambda >= 0)

  # one coefficient per column of X plus the intercept (was hard-coded to 101)
  unpen_beta <- lasso_betas(X, y, beta_init = numeric(ncol(X) + 1))
  old_loss <- unpen_beta$loss
  lasso_converged <- FALSE
  loop_count <- 0

  while (!lasso_converged && loop_count < max_iter) {
    beta_LS <- optim(unpen_beta$betas, # guess
      fn = function(beta) penalized_loss(X, y, beta, lam = 0, alpha = 1),
      gr = function(beta) ridge_grad(X, y, beta, lam = 0),
      method = "BFGS"
    )

    # soft-threshold: move each coefficient toward zero by lambda and zero
    # out the ones that would cross it
    shrunk <- beta_LS$par - sign(beta_LS$par) * lambda
    shrunk[abs(beta_LS$par) <= lambda] <- 0

    unpen_beta <- lasso_betas(X, y, beta_init = shrunk)
    loss_change <- abs(unpen_beta$loss - old_loss)
    lasso_converged <- loss_change < tol
    if (loop_count %% 25 == 0) {
      cat("Loop:", loop_count, "Convergence:", loss_change, "\n")
    }
    old_loss <- unpen_beta$loss
    loop_count <- loop_count + 1
  }

  if (!lasso_converged) {
    warning("lasso_coeff stopped after max_iter = ", max_iter,
      " passes without reaching tol",
      call. = FALSE
    )
  }
  print(loop_count)
  unpen_beta
}
if (!require("caret")) install.packages("caret")
library(caret)
# K-fold cross-validation over a grid of ridge penalties.
#
# Arguments:
#   X         predictor matrix or data frame (rows = observations).
#   y         response vector, passed to caret::createFolds() and to the
#             fitting / loss functions.
#   num_folds number of folds requested from caret::createFolds().
#   tune_grid vector of candidate lam values.
#   verbose   if TRUE, print progress for every fold.
#
# Returns list(
#   cv.table = data frame with one row per tune_grid value: one column per
#              fold (held-out loss), then "means" (row mean over folds) and
#              "hyp" (the lam value),
#   lam.min  = the lam with the smallest mean held-out loss).
tune_ridge <- function(X, y, num_folds, tune_grid, verbose = TRUE) {
  folds <- caret::createFolds(y, k = num_folds)
  cv.results <- vector("list", length(folds))

  # iterate over the folds actually returned, so indexing and the column
  # names assigned below always agree with `folds`
  for (fold.id in seq_along(folds)) {
    te.idx <- folds[[fold.id]]
    if (verbose) {
      cat("fold", fold.id, "of", num_folds, "\n")
      cat("\t inner loop over hyperparameters...\n")
    }

    # drop = FALSE keeps X two-dimensional even with a single predictor or a
    # single held-out row; otherwise nrow()/ncol() downstream would be NULL
    X_train <- X[-te.idx, , drop = FALSE]
    y_train <- y[-te.idx]
    X_test <- X[te.idx, , drop = FALSE]
    y_test <- y[te.idx]

    # iterate over hyperparameter; vapply guarantees one numeric per lam
    scores <- vapply(
      tune_grid,
      function(lam) {
        # train beta's
        btrain <- ridge_betas(X_train, y_train,
          beta_init = NULL,
          lam = lam, method = "BFGS"
        )$betas
        # get test loss with training beta's
        # NOTE(review): lam is passed through, so the held-out score includes
        # the penalty term - confirm this is intended rather than lam = 0.
        penalized_loss(X_test, y_test, btrain, lam = lam, alpha = 0)
      },
      numeric(1)
    )
    cv.results[[fold.id]] <- scores # scores vector
  }

  cv.results <- data.frame(cv.results) # turn list to df
  cv.results$means <- rowMeans(as.matrix(cv.results))
  cv.results$hyp <- tune_grid
  colnames(cv.results) <- c(names(folds), "means", "hyp")

  #### Select best performance: lowest mean held-out loss
  best.idx <- which.min(cv.results$means)
  list(
    cv.table = cv.results,
    lam.min = cv.results$hyp[best.idx]
  )
}