######################################################################
-class Task:
- def batches(self, split="train", nb_to_use=-1, desc=None):
- pass
-
- def vocabulary_size(self):
- pass
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- pass
-
-
-######################################################################
-
-import sky
-
-
-class QuizzMachine(Task):
- def save_image(self, input, result_dir, filename, logger):
- img = sky.seq2img(input.to("cpu"), self.height, self.width)
- image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
- logger(f"wrote {image_name}")
-
- def save_quizzes(self, input, result_dir, filename_prefix, logger):
- self.save_image(input, result_dir, filename_prefix + ".png", logger)
-
+class QuizzMachine:
def make_ar_mask(self, input):
    """Build the 0/1 mask of positions located past the middle of each row.

    The returned tensor has the same shape as `input` and is of integer
    (int64) type: entry (i, j) is 1 when the column index j is strictly
    greater than `input.size(1) // 2`, and 0 otherwise. Every row carries
    the same pattern, since it depends only on the column index.
    """
    seq_len = input.size(1)
    positions = torch.arange(seq_len, device=input.device)
    past_middle = (positions > seq_len // 2).long()
    # Add a leading row axis, then broadcast the single row to input's shape
    return past_middle[None, :].expand_as(input)
def __init__(
self,
+ problem,
nb_train_samples,
nb_test_samples,
batch_size,
):
super().__init__()
+ self.problem = problem
self.batch_size = batch_size
self.device = device
- self.height = 6
- self.width = 8
- self.train_w_quizzes = sky.generate_seq(
- nb_train_samples, height=self.height, width=self.width
- ).to(device)
-
- self.test_w_quizzes = sky.generate_seq(
- nb_test_samples, height=self.height, width=self.width
- ).to(device)
+ self.train_w_quizzes = self.problem.generate_seq(nb_train_samples).to(device)
+ self.test_w_quizzes = self.problem.generate_seq(nb_test_samples).to(device)
self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
self.test_c_quizzes = []
if result_dir is not None:
- self.save_quizzes(
- self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
+ self.problem.save_quizzes(
+ self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes"
)
def batches(self, split="train", desc=None):
device=self.device,
)
- self.save_quizzes(
- result[:72],
- result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- logger,
+ self.problem.save_quizzes(
+ result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}"
)
return main_test_accuracy
input = self.train_w_quizzes if for_train else self.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
- input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to(
- self.device
- )
+ input[-nb:] = self.problem.generate_seq(nb).to(self.device)
def store_c_quizzes(self, new_c_quizzes, for_train=True):
if for_train:
# Generate quizzes with model
c_quizzes = torch.empty(
- nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
+ nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
temperature = 1
- d_temperature = 1
+ d_temperature = 1 / 3
while True:
seq_logproba[...] = 0
ave_seq_logproba = seq_logproba.mean()
- logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
-
if min_ave_seq_logproba is None:
break
# Oh man that's ugly
- if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+ if ave_seq_logproba < min_ave_seq_logproba:
if d_temperature > 0:
d_temperature *= -1 / 3
temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba:
+ elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
if d_temperature < 0:
d_temperature *= -1 / 3
temperature += d_temperature
###############################################################
# Create the reverse quizzes
- l = self.height * self.width
+ # Fetch the two direction tokens once through the problem's accessor and
+ # use these locals below, instead of reaching for attributes on the
+ # problem object.
+ token_forward, token_backward = self.problem.direction_tokens()
+
+ # Middle column of a quiz: the direction token is read from there
+ l = (c_quizzes.size(1) - 1) // 2
direction = c_quizzes[:, l : l + 1]
- direction = sky.token_forward * (
- direction == sky.token_backward
- ) + sky.token_backward * (direction == sky.token_forward)
+ # Swap the direction: backward becomes forward and forward becomes
+ # backward (any other value maps to 0, as before)
+ direction = token_forward * (direction == token_backward) + token_backward * (
+ direction == token_forward
+ )
reverse_c_quizzes = torch.cat(
[c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
)