Update.
authorFrançois Fleuret <francois@fleuret.org>
Wed, 31 Jul 2024 05:07:28 +0000 (07:07 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Wed, 31 Jul 2024 05:07:28 +0000 (07:07 +0200)
main.py

diff --git a/main.py b/main.py
index 50e34a8..d50837a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -107,7 +107,7 @@ parser.add_argument("--nb_rounds", type=int, default=1)
 
 parser.add_argument("--dirty_debug", action="store_true", default=False)
 
-parser.add_argument("--autoencoder_dim", type=int, default=-1)
+parser.add_argument("--test_generator", action="store_true", default=False)
 
 ######################################################################
 
@@ -1023,64 +1023,73 @@ if args.dirty_debug:
 
 ######################################################################
 
-# DIRTY TEST
+if args.test_generator:
+    token_prolog_0 = vocabulary_size + 0
+    token_prolog_1 = vocabulary_size + 1
+    token_prolog_2 = vocabulary_size + 2
+    generator_vocabulary_size = vocabulary_size + 3
 
-# train_complexifier(models[0], models[1], models[2])
+    generator = mygpt.MyGPT(
+        vocabulary_size=generator_vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        causal=True,
+        dropout=args.dropout,
+    ).to(main_device)
 
-# exit(0)
+    generator.main_test_accuracy = 0.0
 
-######################################################################
+    for n_epoch in range(args.nb_epochs):
+        one_generator_epoch(
+            generator,
+            quiz_machine=quiz_machine,
+            models=models,
+            w_quizzes=True,
+            local_device=main_device,
+        )
 
-token_prolog_0 = vocabulary_size + 0
-token_prolog_1 = vocabulary_size + 1
-token_prolog_2 = vocabulary_size + 2
-generator_vocabulary_size = vocabulary_size + 3
-
-generator = mygpt.MyGPT(
-    vocabulary_size=generator_vocabulary_size,
-    dim_model=args.dim_model,
-    dim_keys=args.dim_keys,
-    dim_hidden=args.dim_hidden,
-    nb_heads=args.nb_heads,
-    nb_blocks=args.nb_blocks,
-    causal=True,
-    dropout=args.dropout,
-).to(main_device)
-
-generator.main_test_accuracy = 0.0
-
-for n_epoch in range(25):
-    one_generator_epoch(
-        generator,
-        quiz_machine=quiz_machine,
-        models=models,
-        w_quizzes=True,
-        local_device=main_device,
-    )
+        filename = f"generator.pth"
+        torch.save(
+            (generator.state_dict(), generator.main_test_accuracy),
+            os.path.join(args.result_dir, filename),
+        )
+        log_string(f"wrote {filename}")
 
-    c_quizzes = generate_c_quizz_with_generator(
-        generator, quiz_machine, args.batch_size
-    )
+        c_quizzes = generate_c_quizz_with_generator(
+            generator, quiz_machine, args.batch_size
+        )
 
-    seq_logproba = quiz_machine.models_logprobas(
-        models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
-    ) + quiz_machine.models_logprobas(
-        models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
-    )
+        seq_logproba = quiz_machine.models_logprobas(
+            models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1)
+        ) + quiz_machine.models_logprobas(
+            models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1)
+        )
 
-    print(seq_logproba.exp())
+        print(seq_logproba.exp())
 
+        comments = []
 
-one_generator_epoch(
-    generator,
-    quiz_machine=quiz_machine,
-    models=models,
-    w_quizzes=False,
-    local_device=main_device,
-)
+        for l in seq_logproba:
+            comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
+
+        filename = f"generator_batch_{n_epoch:04d}.png"
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir, filename, c_quizzes, comments=comments
+        )
+        log_string(f"wrote {filename}")
 
-exit(0)
+    one_generator_epoch(
+        generator,
+        quiz_machine=quiz_machine,
+        models=models,
+        w_quizzes=False,
+        local_device=main_device,
+    )
 
+    exit(0)
 
 ######################################################################