X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=cnn-svrt.py;h=cfef09da934fb5d82181715f428dab1a261d04ba;hb=37030c396217cfe89f8dfa2b9e10ff1ec783a5a7;hp=153bdc9d23a18a7abe67cfbe3f72246a5ee2fa83;hpb=d21f7d8eecb12aa4cc60360db6aa33324327e987;p=pysvrt.git diff --git a/cnn-svrt.py b/cnn-svrt.py index 153bdc9..cfef09d 100755 --- a/cnn-svrt.py +++ b/cnn-svrt.py @@ -255,9 +255,9 @@ for arg in vars(args): ###################################################################### def int_to_suffix(n): - if n > 1000000 and n%1000000 == 0: + if n >= 1000000 and n%1000000 == 0: return str(n//1000000) + 'M' - elif n > 1000 and n%1000 == 0: + elif n >= 1000 and n%1000 == 0: return str(n//1000) + 'K' else: return str(n) @@ -284,8 +284,8 @@ for problem_number in range(1, 24): if torch.cuda.is_available(): model.cuda() - model_filename = model.name + '_' + \ - str(problem_number) + '_' + \ + model_filename = model.name + '_pb:' + \ + str(problem_number) + '_ns:' + \ int_to_suffix(args.nb_train_samples) + '.param' nb_parameters = 0