From 7a46506f936bad2e136424b68cbd92890d46830c Mon Sep 17 00:00:00 2001 From: Francois Fleuret Date: Thu, 15 Jun 2017 21:24:25 +0200 Subject: [PATCH] Replace the numbers of samples by numbers of batches of samples. --- cnn-svrt.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/cnn-svrt.py b/cnn-svrt.py index e7e4574..ab1b363 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -44,18 +44,22 @@ parser = argparse.ArgumentParser( formatter_class = argparse.ArgumentDefaultsHelpFormatter ) -parser.add_argument('--nb_train_samples', - type = int, default = 100000, +parser.add_argument('--nb_train_batches', + type = int, default = 1000, help = 'How many samples for train') -parser.add_argument('--nb_test_samples', - type = int, default = 10000, +parser.add_argument('--nb_test_batches', + type = int, default = 100, help = 'How many samples for test') parser.add_argument('--nb_epochs', type = int, default = 50, help = 'How many training epochs') +parser.add_argument('--batch_size', + type = int, default = 100, + help = 'Mini-batch size') + parser.add_argument('--log_file', type = str, default = 'cnn-svrt.log', help = 'Log file name') @@ -120,12 +124,13 @@ class AfrozeShallowNet(nn.Module): return x def train_model(model, train_input, train_target): + bs = args.batch_size criterion = nn.CrossEntropyLoss() if torch.cuda.is_available(): criterion.cuda() - optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100 + optimizer = optim.SGD(model.parameters(), lr = 1e-2) for k in range(0, args.nb_epochs): acc_loss = 0.0 @@ -142,9 +147,10 @@ def train_model(model, train_input, train_target): ###################################################################### -def nb_errors(model, data_input, data_target, bs = 100): - ne = 0 +def nb_errors(model, data_input, data_target): + bs = args.batch_size + ne = 0 for b in range(0, data_input.size(0), bs): output = model.forward(data_input.narrow(0, b, bs)) wta_prediction = output.data.max(1)[1].view(-1) @@ -161,8 +167,10 @@ for arg in vars(args): log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg))) for problem_number in range(1, 24): - train_input, train_target = generate_set(problem_number, args.nb_train_samples) - test_input, test_target = generate_set(problem_number, args.nb_test_samples) + train_input, train_target = generate_set(problem_number, + args.nb_train_batches * args.batch_size) + test_input, test_target = generate_set(problem_number, + args.nb_test_batches * args.batch_size) model = AfrozeShallowNet() if torch.cuda.is_available(): -- 2.39.5