from torch.nn import functional as F
import ffutils
-import mygpt, quizz_machine
+import mygpt
+import sky, quizz_machine
# world quizzes vs. culture quizzes
assert args.nb_test_samples % args.batch_size == 0
quizz_machine = quizz_machine.QuizzMachine(
+ sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
quizz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
quizz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
- quizz_machine.save_quizzes(
+ quizz_machine.problem.save_quizzes(
new_c_quizzes[:72],
args.result_dir,
f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
--- /dev/null
+#!/usr/bin/env python
+
+# Any copyright is dedicated to the Public Domain.
+# https://creativecommons.org/publicdomain/zero/1.0/
+
+# Written by Francois Fleuret <francois@fleuret.org>
+
+
+class Problem:
+ def generate_seq(self, nb):
+ pass
+
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ pass
+
+ def direction_tokens(self):
+ pass
######################################################################
-import sky
-
class QuizzMachine:
def make_ar_mask(self, input):
def __init__(
self,
+ problem,
nb_train_samples,
nb_test_samples,
batch_size,
):
super().__init__()
- self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
+ self.problem = problem
self.batch_size = batch_size
self.device = device
ave_seq_logproba = seq_logproba.mean()
- logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
-
if min_ave_seq_logproba is None:
break
# Oh man that's ugly
- if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+ if ave_seq_logproba < min_ave_seq_logproba:
if d_temperature > 0:
d_temperature *= -1 / 3
temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba:
+ elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
if d_temperature < 0:
d_temperature *= -1 / 3
temperature += d_temperature
######################################################################
+import problem
-class Problem:
- def generate_seq(self, nb_train_samples):
- pass
- def save_quizzes(self, input, result_dir, filename_prefix, logger):
- pass
-
- def direction_tokens(self):
- pass
-
-
-class Sky:
+class Sky(problem.Problem):
colors = torch.tensor(
[
[255, 255, 255],