Skip to content
This repository was archived by the owner on Aug 3, 2021. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file modified .gitignore
100644 → 100755
Empty file.
Empty file modified AutoEncoder.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified LICENSE
100644 → 100755
Empty file.
Empty file modified README.md
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRec/ConstrainedRecoEncoder.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRec/ConstrainedRecoEncoderNoLastLayerNl.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRec/RecoEncoder.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRec/RecoEncoderNoLastLayerNl.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRec/done.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRec/netflix_data_preprocess.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRecAllSplits/RecoEncoder1Y.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRecAllSplits/RecoEncoderN3m.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRecAllSplits/RecoEncoderN6m.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRecAllSplits/RecoEncoderNF.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRecAllSplits/done.job
100644 → 100755
Empty file.
Empty file modified azkaban/AutoRecAllSplits/netflix_data_preprocess.job
100644 → 100755
Empty file.
Empty file modified compute_RMSE.py
100644 → 100755
Empty file.
Empty file modified data_utils/movie_lense_data_converter.py
100644 → 100755
Empty file.
Empty file modified data_utils/netflix_data_convert.py
100644 → 100755
Empty file.
Empty file modified infer.py
100644 → 100755
Empty file.
Empty file modified reco_encoder/__init__.py
100644 → 100755
Empty file.
Empty file modified reco_encoder/data/__init__.py
100644 → 100755
Empty file.
Empty file modified reco_encoder/data/input_layer.py
100644 → 100755
Empty file.
Empty file modified reco_encoder/model/__init__.py
100644 → 100755
Empty file.
21 changes: 11 additions & 10 deletions reco_encoder/model/model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,32 @@
def activation(input, kind):
#print("Activation: {}".format(kind))
if kind == 'selu':
return F.selu(input)
res = F.selu(input)
elif kind == 'relu':
return F.relu(input)
res = F.relu(input)
elif kind == 'relu6':
return F.relu6(input)
res = F.relu6(input)
elif kind == 'sigmoid':
return F.sigmoid(input)
res = F.sigmoid(input)
elif kind == 'tanh':
return F.tanh(input)
res = F.tanh(input)
elif kind == 'elu':
return F.elu(input)
res = F.elu(input)
elif kind == 'lrelu':
return F.leaky_relu(input)
res = F.leaky_relu(input)
elif kind == 'swish':
return input*F.sigmoid(input)
res = input*F.sigmoid(input)
elif kind == 'none':
return input
res = input
else:
raise ValueError('Unknown non-linearity type')
return res

def MSEloss(inputs, targets, size_avarage=False):
mask = targets != 0
num_ratings = torch.sum(mask.float())
criterion = nn.MSELoss(size_average=size_avarage)
return criterion(inputs * mask.float(), targets), Variable(torch.Tensor([1.0])) if size_avarage else num_ratings
return criterion(inputs.float() * mask.float(), targets.float()), Variable(torch.Tensor([1.0])) if size_avarage else num_ratings

class AutoEncoder(nn.Module):
def __init__(self, layer_sizes, nl_type='selu', is_constrained=True, dp_drop_prob=0.0, last_layer_activations=True):
Expand Down
100 changes: 79 additions & 21 deletions run.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from math import sqrt
import numpy as np
import os
import sys

scale_factor = 128.0

parser = argparse.ArgumentParser(description='RecoEncoder')
parser.add_argument('--lr', type=float, default=0.00001, metavar='N',
Expand Down Expand Up @@ -65,12 +68,12 @@ def do_eval(encoder, evaluation_data_layer):
denom = 0.0
total_epoch_loss = 0.0
for i, (eval, src) in enumerate(evaluation_data_layer.iterate_one_epoch_eval()):
inputs = Variable(src.cuda().to_dense() if use_gpu else src.to_dense())
targets = Variable(eval.cuda().to_dense() if use_gpu else eval.to_dense())
inputs = Variable(src.cuda().to_dense().half() if use_gpu else src.to_dense().half())
targets = Variable(eval.cuda().to_dense().half() if use_gpu else eval.to_dense().half())
outputs = encoder(inputs)
loss, num_ratings = model.MSEloss(outputs, targets)
total_epoch_loss += loss.data[0]
denom += num_ratings.data[0]
denom += num_ratings.data.float()[0]
return sqrt(total_epoch_loss / denom)

def log_var_and_grad_summaries(logger, layers, global_step, prefix, log_histograms=False):
Expand All @@ -85,20 +88,45 @@ def log_var_and_grad_summaries(logger, layers, global_step, prefix, log_histogra
"""
for ind, w in enumerate(layers):
# Variables
w_var = w.data.cpu().numpy()
w_var = w.data.float().cpu().numpy()
logger.scalar_summary("Variables/FrobNorm/{}_{}".format(prefix, ind), np.linalg.norm(w_var),
global_step)
if log_histograms:
logger.histo_summary(tag="Variables/{}_{}".format(prefix, ind), values=w.data.cpu().numpy(),
logger.histo_summary(tag="Variablmodeles/{}_{}".format(prefix, ind), values=w.data.cpu().numpy(),
step=global_step)

# Gradients
w_grad = w.grad.data.cpu().numpy()
w_grad = w.grad.float().data.cpu().numpy()
logger.scalar_summary("Gradients/FrobNorm/{}_{}".format(prefix, ind), np.linalg.norm(w_grad),
global_step)
if log_histograms:
logger.histo_summary(tag="Gradients/{}_{}".format(prefix, ind), values=w.grad.data.cpu().numpy(),
step=global_step)
logger.histo_summary(tag="Gradients/{}_{}".format(prefix, ind), values=w.grad.float().data.cpu().numpy(),
step=global_step)


######
def prep_param_lists(model):
  """Build paired parameter lists for mixed-precision training.

  Returns a tuple ``(model_params, master_params)``: the model's trainable
  (``requires_grad``) parameters, and detached fp32 copies of them — the
  "master" weights that the optimizer actually updates.
  """
  model_params = []
  master_params = []
  for param in model.parameters():
    if not param.requires_grad:
      continue
    model_params.append(param)
    # fp32 master copy, detached from the model's graph but trainable itself.
    fp32_copy = param.clone().float().detach()
    fp32_copy.requires_grad = True
    master_params.append(fp32_copy)
  return model_params, master_params

def master_params_to_model_params(model_params, master_params):
  """Copy the (optimizer-updated) fp32 master weights back into the model.

  The two lists are expected to be parallel, as produced by
  ``prep_param_lists``.
  """
  for dst, src in zip(model_params, master_params):
    dst.data.copy_(src.data)

def model_grads_to_master_grads(model_params, master_params):
  """Copy gradients from the model parameters into the fp32 master copies.

  A master gradient buffer is allocated lazily on first use (uninitialized,
  matching the master parameter's shape) and then overwritten in place with
  the model's gradient.
  """
  for src, dst in zip(model_params, master_params):
    if dst.grad is None:
      # NOTE(review): legacy Variable wrapper kept for behavior parity with
      # the original; on modern PyTorch it is a no-op around the tensor.
      dst.grad = Variable(dst.data.new(*dst.data.size()))
    dst.grad.data.copy_(src.grad.data)

######



def main():
logger = Logger(args.logdir)
Expand Down Expand Up @@ -127,6 +155,8 @@ def main():
is_constrained=args.constrained,
dp_drop_prob=args.drop_prob,
last_layer_activations=not args.skip_last_layer_nl)


os.makedirs(args.logdir, exist_ok=True)
model_checkpoint = args.logdir + "/model"
path_to_model = Path(model_checkpoint)
Expand All @@ -147,27 +177,31 @@ def main():
rencoder = nn.DataParallel(rencoder,
device_ids=gpu_ids)

if use_gpu: rencoder = rencoder.cuda()
if use_gpu: rencoder = rencoder.cuda().half()

##########
model_params, master_params = prep_param_lists(rencoder)
##########

if args.optimizer == "adam":
optimizer = optim.Adam(rencoder.parameters(),
optimizer = optim.Adam(master_params,#rencoder.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)
elif args.optimizer == "adagrad":
optimizer = optim.Adagrad(rencoder.parameters(),
optimizer = optim.Adagrad(master_params, #rencoder.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)
elif args.optimizer == "momentum":
optimizer = optim.SGD(rencoder.parameters(),
optimizer = optim.SGD(master_params,#rencoder.parameters(),
lr=args.lr, momentum=0.9,
weight_decay=args.weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[24, 36, 48, 66, 72], gamma=0.5)
elif args.optimizer == "rmsprop":
optimizer = optim.RMSprop(rencoder.parameters(),
optimizer = optim.RMSprop(master_params,#rencoder.parameters(),
lr=args.lr, momentum=0.9,
weight_decay=args.weight_decay)
else:
raise ValueError('Unknown optimizer kind')
raise ValueError('Unknown optimizer kind')

t_loss = 0.0
t_loss_denom = 0.0
Expand All @@ -185,19 +219,32 @@ def main():
if args.optimizer == "momentum":
scheduler.step()
for i, mb in enumerate(data_layer.iterate_one_epoch()):
inputs = Variable(mb.cuda().to_dense() if use_gpu else mb.to_dense())
optimizer.zero_grad()
inputs = Variable(mb.cuda().to_dense().half() if use_gpu else mb.to_dense())
rencoder.zero_grad()
outputs = rencoder(inputs)
loss, num_ratings = model.MSEloss(outputs, inputs)
loss = loss / num_ratings
loss.backward()
loss = loss / num_ratings.float()
scaled_loss = scale_factor * loss.float()
scaled_loss.backward()
#loss.backward()

##
model_grads_to_master_grads(model_params, master_params)
for param in master_params:
param.grad.data.mul_(1./scale_factor)
##
optimizer.step()
##
master_params_to_model_params(model_params, master_params)
##

global_step += 1
t_loss += loss.data[0]
t_loss_denom += 1

if i % args.summary_frequency == 0:
print('[%d, %5d] RMSE: %.7f' % (epoch, i, sqrt(t_loss / t_loss_denom)))
sys.stdout.flush()
logger.scalar_summary("Training_RMSE", sqrt(t_loss/t_loss_denom), global_step)
t_loss = 0
t_loss_denom = 0.0
Expand All @@ -217,16 +264,27 @@ def main():
inputs = Variable(outputs.data)
if args.noise_prob > 0.0:
inputs = dp(inputs)
optimizer.zero_grad()
rencoder.zero_grad()
outputs = rencoder(inputs)
loss, num_ratings = model.MSEloss(outputs, inputs)
loss = loss / num_ratings
loss.backward()
loss = loss / num_ratings.float()
scaled_loss = scale_factor * loss.float()
scaled_loss.backward()
#loss.backward()
##
model_grads_to_master_grads(model_params, master_params)
for param in master_params:
param.grad.data.mul_(1. / scale_factor)
##
optimizer.step()
##
master_params_to_model_params(model_params, master_params)
##

e_end_time = time.time()
print('Total epoch {} finished in {} seconds with TRAINING RMSE loss: {}'
.format(epoch, e_end_time - e_start_time, sqrt(total_epoch_loss/denom)))
sys.stdout.flush()
logger.scalar_summary("Training_RMSE_per_epoch", sqrt(total_epoch_loss/denom), epoch)
logger.scalar_summary("Epoch_time", e_end_time - e_start_time, epoch)
if epoch % 3 == 0 or epoch == args.num_epochs - 1:
Expand Down
Empty file modified test/__init__.py
100644 → 100755
Empty file.
Empty file modified test/data_layer_tests.py
100644 → 100755
Empty file.
Empty file.
Empty file modified test/testData_iRec/_SUCCESS
100644 → 100755
Empty file.
Empty file.
Empty file.
Empty file modified test/testData_uRec/._SUCCESS.crc
100644 → 100755
Empty file.
Empty file.
Empty file.
Empty file.
Empty file.
Empty file modified test/test_model.py
100644 → 100755
Empty file.