X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=main.py;h=ca0d1524b04d740c8aaa38fa970b503d868bdba8;hb=d16410119a4e5c1117f7f0fbbe80e3e54f81f28b;hp=18b19db81d59927ac6d25d575999b91d5b42795f;hpb=0ab695df8f6a2a0cc70a424e57943a0d5606903b;p=culture.git diff --git a/main.py b/main.py index 18b19db..ca0d152 100755 --- a/main.py +++ b/main.py @@ -474,6 +474,7 @@ elif args.task == "world": nb_train_samples=args.nb_train_samples, nb_test_samples=args.nb_test_samples, batch_size=args.physical_batch_size, + result_dir=args.result_dir, logger=log_string, device=device, ) @@ -853,7 +854,7 @@ def one_epoch(model, task, learning_rate): train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - log_string(f"train)perplexity {n_epoch} {train_perplexity}") + log_string(f"train_perplexity {n_epoch} {train_perplexity}") ###################################################################### @@ -878,7 +879,7 @@ def run_tests(model, task, deterministic_synthesis): nb_test_samples += input.size(0) - task.produce_results( + main_test_accuracy = task.produce_results( n_epoch=n_epoch, model=model, result_dir=args.result_dir, @@ -887,7 +888,10 @@ def run_tests(model, task, deterministic_synthesis): ) test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - log_string(f"test)perplexity {n_epoch} {test_perplexity}") + + log_string(f"test_perplexity {n_epoch} {test_perplexity}") + + return main_test_accuracy ###################################################################### @@ -897,7 +901,39 @@ for n_epoch in range(nb_epochs_finished, args.nb_epochs): one_epoch(model, task, learning_rate) - run_tests(model, task, deterministic_synthesis=True) + test_accuracy = run_tests(model, task, deterministic_synthesis=False) + + # -------------------------------------------- + + if test_accuracy >= 0.8: + nb_for_train, nb_for_test = 1000, 100 + kept = [] + + while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test: + new_quizzes, nb_correct = task.create_new_quizzes( + n_epoch=n_epoch, + result_dir=args.result_dir, + logger=log_string, + nb=nb_required, + model=model, + nb_runs=10, + ) + + to_keep = new_quizzes[torch.logical_and(nb_correct >= 8, nb_correct < 10)] + log_string(f"keep {to_keep.size(0)} quizzes") + kept.append(to_keep) + + new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test] + + task.store_new_quizzes(new_quizzes[:nb_for_train], train=True) + task.store_new_quizzes(new_quizzes[nb_for_train:], train=False) + + task.save_image( + new_quizzes[:96], + args.result_dir, + f"world_new_{n_epoch:04d}.png", + log_string, + ) # --------------------------------------------