from mygpt import BracketedSequence
-# from graph import save_attention_image
-save_attention_image = None
-
######################################################################
def save_image(self, input, result_dir, filename, logger):
img = world.sample2img(input.to("cpu"), self.height, self.width)
image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
def make_ar_mask(self, input):
self.batch_size = batch_size
self.device = device
- self.height = 6
- self.width = 8
+ self.height = 7
+ self.width = 9
self.train_input = world.generate(
nb_train_samples, height=self.height, width=self.width
nb_test_samples, height=self.height, width=self.width
).to(device)
+ # print()
+ # for a in world.seq2str(self.train_input):
+ # print(a)
+ # for a in world.seq2str(self.test_input):
+ # print(a)
+ # exit(0)
+
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
self.train_quizzes = []
if result_dir is not None:
self.save_image(
- self.train_input[:96], result_dir, f"world_train.png", logger
+ self.train_input[:72], result_dir, f"world_train.png", logger
)
def batches(self, split="train", desc=None):
)
self.save_image(
- result[:96],
+ result[:72],
result_dir,
f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
logger,
model,
other_models,
):
- new_quizzes = torch.empty(
+ ###############################################################
+ # Generate quizzes with model
+
+ quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(new_quizzes.size(), 1, device=self.device)
+ ar_mask = torch.full(quizzes.size(), 1, device=self.device)
masked_inplace_autoregression(
model,
self.batch_size,
- new_quizzes,
+ quizzes,
ar_mask,
deterministic_synthesis=False,
- progress_bar_desc="new quizzes",
+ progress_bar_desc="creating quizzes",
device=self.device,
)
- ar_mask = self.make_ar_mask(new_quizzes)
+ ###############################################################
+ # Create the reverse quizzes
+
+ l = self.height * self.width
+ direction = quizzes[:, l : l + 1]
+ direction = world.token_forward * (
+ direction == world.token_backward
+ ) + world.token_backward * (direction == world.token_forward)
+ reverse_quizzes = torch.cat(
+ [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+ )
+
+ ar_mask = self.make_ar_mask(quizzes)
- nb_correct = 0
+ ###############################################################
+ # Check how many of the other models can solve them in both
+ # directions
+
+ nb_correct = []
for m in other_models:
- result = new_quizzes.clone()
+ result = quizzes.clone()
masked_inplace_autoregression(
m,
device=self.device,
)
- l = self.height * self.width
- direction = new_quizzes[:, l : l + 1]
- direction = world.token_forward * (
- direction == world.token_backward
- ) + world.token_backward * (direction == world.token_forward)
- inverted_quizzes = torch.cat(
- [new_quizzes[:, l + 1 :], direction, new_quizzes[:, :l]], dim=1
- )
+ correct = (quizzes == result).long().min(dim=-1).values
- inverted_result = inverted_quizzes.clone()
+ reverse_result = reverse_quizzes.clone()
masked_inplace_autoregression(
m,
self.batch_size,
- inverted_result,
+ reverse_result,
ar_mask,
deterministic_synthesis=True,
- progress_bar_desc="solving reverse quizzes",
+ progress_bar_desc="solving reversed quizzes",
device=self.device,
)
- nb_correct += (new_quizzes == result).long().min(dim=-1).values * (
- inverted_quizzes == inverted_result
- ).long().min(dim=-1).values
+ reverse_correct = (
+ (reverse_quizzes == reverse_result).long().min(dim=-1).values
+ )
+
+ nb_correct.append((correct * reverse_correct)[None, :])
+
+ nb_correct = torch.cat(nb_correct, dim=0)
+
+ filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+ with open(filename, "w") as f:
+ for k in nb_correct:
+ f.write(f"{k}\n")
- return new_quizzes, nb_correct
+ return quizzes, nb_correct.sum(dim=0)