Update.
authorFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 16:16:44 +0000 (18:16 +0200)
committerFrançois Fleuret <francois@fleuret.org>
Tue, 25 Jun 2024 16:16:44 +0000 (18:16 +0200)
main.py
problem.py [new file with mode: 0755]
quizz_machine.py
sky.py

diff --git a/main.py b/main.py
index 05c3557..524715a 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -12,7 +12,8 @@ from torch import nn
 from torch.nn import functional as F
 
 import ffutils
-import mygpt, quizz_machine
+import mygpt
+import sky, quizz_machine
 
 # world quizzes vs. culture quizzes
 
@@ -210,6 +211,7 @@ assert args.nb_train_samples % args.batch_size == 0
 assert args.nb_test_samples % args.batch_size == 0
 
 quizz_machine = quizz_machine.QuizzMachine(
+    sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2),
     nb_train_samples=args.nb_train_samples,
     nb_test_samples=args.nb_test_samples,
     batch_size=args.physical_batch_size,
@@ -390,7 +392,7 @@ def create_c_quizzes(
     quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
     quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
 
-    quizz_machine.save_quizzes(
+    quizz_machine.problem.save_quizzes(
         new_c_quizzes[:72],
         args.result_dir,
         f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
diff --git a/problem.py b/problem.py
new file mode 100755 (executable)
index 0000000..25ffc49
--- /dev/null
@@ -0,0 +1,17 @@
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+
+class Problem:
+    def generate_seq(self, nb):
+        pass
+
+    def save_quizzes(self, input, result_dir, filename_prefix, logger):
+        pass
+
+    def direction_tokens(self):
+        pass
index 28b94d1..d63855c 100755 (executable)
@@ -66,8 +66,6 @@ def masked_inplace_autoregression(
 
 ######################################################################
 
-import sky
-
 
 class QuizzMachine:
     def make_ar_mask(self, input):
@@ -76,6 +74,7 @@ class QuizzMachine:
 
     def __init__(
         self,
+        problem,
         nb_train_samples,
         nb_test_samples,
         batch_size,
@@ -85,7 +84,7 @@ class QuizzMachine:
     ):
         super().__init__()
 
-        self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
+        self.problem = problem
         self.batch_size = batch_size
         self.device = device
 
@@ -267,17 +266,15 @@ class QuizzMachine:
 
             ave_seq_logproba = seq_logproba.mean()
 
-            logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
-
             if min_ave_seq_logproba is None:
                 break
 
             # Oh man that's ugly
-            if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+            if ave_seq_logproba < min_ave_seq_logproba:
                 if d_temperature > 0:
                     d_temperature *= -1 / 3
                 temperature += d_temperature
-            elif ave_seq_logproba > min_ave_seq_logproba:
+            elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
                 if d_temperature < 0:
                     d_temperature *= -1 / 3
                 temperature += d_temperature
diff --git a/sky.py b/sky.py
index cb25ea0..ec476a6 100755 (executable)
--- a/sky.py
+++ b/sky.py
@@ -14,19 +14,10 @@ from torch.nn import functional as F
 
 ######################################################################
 
+import problem
 
-class Problem:
-    def generate_seq(self, nb_train_samples):
-        pass
 
-    def save_quizzes(self, input, result_dir, filename_prefix, logger):
-        pass
-
-    def direction_tokens(self):
-        pass
-
-
-class Sky:
+class Sky(problem.Problem):
     colors = torch.tensor(
         [
             [255, 255, 255],