import time
import argparse
+
from colorama import Fore, Back, Style
import torch
help = 'How many samples for test')
parser.add_argument('--nb_epochs',
- type = int, default = 25,
+ type = int, default = 100,
help = 'How many training epochs')
parser.add_argument('--log_file',
print(Fore.RED + 'Logging into ' + args.log_file + Style.RESET_ALL)
def log_string(s):
- s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + \
- str(problem_number) + ' ' + s
+ s = Fore.GREEN + time.ctime() + Style.RESET_ALL + ' ' + s
log_file.write(s + '\n')
log_file.flush()
print(s)
# Afroze's ShallowNet
-# map size nb. maps
-# ----------------------
-# 128x128 1
-# -- conv(21x21) -> 108x108 6
-# -- max(2x2) -> 54x54 6
-# -- conv(19x19) -> 36x36 16
-# -- max(2x2) -> 18x18 16
-# -- conv(18x18) -> 1x1 120
-# -- reshape -> 120 1
-# -- full(120x84) -> 84 1
-# -- full(84x2) -> 2 1
+# map size nb. maps
+# ----------------------
+# input 128x128 1
+# -- conv(21x21 x 6) -> 108x108 6
+# -- max(2x2) -> 54x54 6
+# -- conv(19x19 x 16) -> 36x36 16
+# -- max(2x2) -> 18x18 16
+# -- conv(18x18 x 120) -> 1x1 120
+# -- reshape -> 120 1
+# -- full(120x84) -> 84 1
+# -- full(84x2) -> 2 1
class Net(nn.Module):
def __init__(self):
######################################################################
+for arg in vars(args):
+ log_string('ARGUMENT ' + str(arg) + ' ' + str(getattr(args, arg)))
+
for problem_number in range(1, 24):
train_input, train_target = generate_set(problem_number, args.nb_train_samples)
test_input, test_target = generate_set(problem_number, args.nb_test_samples)
nb_train_errors = nb_errors(model, train_input, train_target)
- log_string('TRAIN_ERROR {:.02f}% {:d} {:d}'.format(
+ log_string('TRAIN_ERROR {:d} {:.02f}% {:d} {:d}'.format(
+ problem_number,
100 * nb_train_errors / train_input.size(0),
nb_train_errors,
train_input.size(0))
nb_test_errors = nb_errors(model, test_input, test_target)
- log_string('TEST_ERROR {:.02f}% {:d} {:d}'.format(
+ log_string('TEST_ERROR {:d} {:.02f}% {:d} {:d}'.format(
+ problem_number,
100 * nb_test_errors / test_input.size(0),
nb_test_errors,
test_input.size(0))