model.train(t)
-######################################################################
-
-
-class Task:
- def batches(self, split="train", nb_to_use=-1, desc=None):
- pass
-
- def vocabulary_size(self):
- pass
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- pass
-
-
######################################################################
import sky
-class QuizzMachine(Task):
+class QuizzMachine:
def save_image(self, input, result_dir, filename, logger):
- img = sky.seq2img(input.to("cpu"), self.height, self.width)
+ img = self.sky.seq2img(input.to("cpu"))
image_name = os.path.join(result_dir, filename)
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
):
super().__init__()
+ self.sky = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
self.batch_size = batch_size
self.device = device
- self.height = 6
- self.width = 8
-
- self.train_w_quizzes = sky.generate_seq(
- nb_train_samples, height=self.height, width=self.width
- ).to(device)
- self.test_w_quizzes = sky.generate_seq(
- nb_test_samples, height=self.height, width=self.width
- ).to(device)
+ self.train_w_quizzes = self.sky.generate_seq(nb_train_samples).to(device)
+ self.test_w_quizzes = self.sky.generate_seq(nb_test_samples).to(device)
self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
input = self.train_w_quizzes if for_train else self.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
- input[-nb:] = sky.generate_seq(nb, height=self.height, width=self.width).to(
- self.device
- )
+ input[-nb:] = self.sky.generate_seq(nb).to(self.device)
def store_c_quizzes(self, new_c_quizzes, for_train=True):
if for_train:
# Generate quizzes with model
c_quizzes = torch.empty(
- nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
+ nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
temperature = 1
- d_temperature = 1
+ d_temperature = 1 / 3
while True:
seq_logproba[...] = 0
###############################################################
# Create the reverse quizzes
- l = self.height * self.width
+ l = (c_quizzes.size(1) - 1) // 2
direction = c_quizzes[:, l : l + 1]
- direction = sky.token_forward * (
- direction == sky.token_backward
- ) + sky.token_backward * (direction == sky.token_forward)
+ direction = self.sky.token_forward * (
+ direction == self.sky.token_backward
+ ) + self.sky.token_backward * (direction == self.sky.token_forward)
reverse_c_quizzes = torch.cat(
[c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
)