projects
/
pysvrt.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
51d45be
)
Replace the numbers of samples by numbers of batches of samples.
author
Francois Fleuret
<francois@fleuret.org>
Thu, 15 Jun 2017 19:24:25 +0000
(21:24 +0200)
committer
Francois Fleuret
<francois@fleuret.org>
Thu, 15 Jun 2017 19:24:25 +0000
(21:24 +0200)
cnn-svrt.py
patch
|
blob
|
history
diff --git
a/cnn-svrt.py
b/cnn-svrt.py
index
e7e4574
..
ab1b363
100755
(executable)
--- a/
cnn-svrt.py
+++ b/
cnn-svrt.py
@@ -44,18 +44,22 @@
parser = argparse.ArgumentParser(
formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
-parser.add_argument('--nb_train_samples',
-                    type = int, default = 100000,
+parser.add_argument('--nb_train_batches',
+                    type = int, default = 1000,
help = 'How many samples for train')
help = 'How many samples for train')
-parser.add_argument('--nb_test_samples',
-                    type = int, default = 10000,
+parser.add_argument('--nb_test_batches',
+                    type = int, default = 100,
help = 'How many samples for test')
parser.add_argument('--nb_epochs',
type = int, default = 50,
help = 'How many training epochs')
help = 'How many samples for test')
parser.add_argument('--nb_epochs',
type = int, default = 50,
help = 'How many training epochs')
+parser.add_argument('--batch_size',
+ type = int, default = 100,
+ help = 'Mini-batch size')
+
parser.add_argument('--log_file',
type = str, default = 'cnn-svrt.log',
help = 'Log file name')
parser.add_argument('--log_file',
type = str, default = 'cnn-svrt.log',
help = 'Log file name')
@@ -120,12 +124,13 @@ class AfrozeShallowNet(nn.Module):
class AfrozeShallowNet(nn.Module):
return x
def train_model(model, train_input, train_target):
return x
def train_model(model, train_input, train_target):
+ bs = args.batch_size
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
criterion.cuda()
criterion = nn.CrossEntropyLoss()
if torch.cuda.is_available():
criterion.cuda()
-    optimizer, bs = optim.SGD(model.parameters(), lr = 1e-2), 100
+    optimizer = optim.SGD(model.parameters(), lr = 1e-2)
for k in range(0, args.nb_epochs):
acc_loss = 0.0
for k in range(0, args.nb_epochs):
acc_loss = 0.0
@@ -142,9 +147,10 @@ def train_model(model, train_input, train_target):
def train_model(model, train_input, train_target):
######################################################################
######################################################################
-def nb_errors(model, data_input, data_target, bs = 100):
-    ne = 0
+def nb_errors(model, data_input, data_target):
+    bs = args.batch_size
+    ne = 0
for b in range(0, data_input.size(0), bs):
output = model.forward(data_input.narrow(0, b, bs))
wta_prediction = output.data.max(1)[1].view(-1)
for b in range(0, data_input.size(0), bs):
output = model.forward(data_input.narrow(0, b, bs))
wta_prediction = output.data.max(1)[1].view(-1)
@@ -161,8 +167,10 @@
for arg in vars(args):
log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
for problem_number in range(1, 24):
log_string('argument ' + str(arg) + ' ' + str(getattr(args, arg)))
for problem_number in range(1, 24):
- train_input, train_target = generate_set(problem_number, args.nb_train_samples)
- test_input, test_target = generate_set(problem_number, args.nb_test_samples)
+ train_input, train_target = generate_set(problem_number,
+ args.nb_train_batches * args.batch_size)
+ test_input, test_target = generate_set(problem_number,
+ args.nb_test_batches * args.batch_size)
model = AfrozeShallowNet()
if torch.cuda.is_available():
model = AfrozeShallowNet()
if torch.cuda.is_available():