batch_size,
input,
ar_mask,
- seq_logproba,
+ summed_logits,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
- seq_logproba=seq_logproba,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
- def save_quizzes(self, input, result_dir, filename_prefix, logger):
- self.save_image(input, result_dir, filename_prefix + ".png", logger)
-
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
self.height = 6
self.width = 8
- self.train_w_quizzes = world.generate_seq(
+ self.train_input = world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_w_quizzes = world.generate_seq(
+ self.test_input = world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
- self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
- self.train_c_quizzes = []
- self.test_c_quizzes = []
+ self.train_quizzes = []
+ self.test_quizzes = []
if result_dir is not None:
- self.save_quizzes(
- self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
+ self.save_image(
+ self.train_input[:72], result_dir, f"world_train.png", logger
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- w_quizzes = self.train_w_quizzes
- c_quizzes = self.train_c_quizzes
+ input = self.train_input
+ quizzes = self.train_quizzes
else:
- w_quizzes = self.test_w_quizzes
- c_quizzes = self.test_c_quizzes
+ input = self.test_input
+ quizzes = self.test_quizzes
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
+ if len(quizzes) > 0:
+ quizzes = torch.cat(quizzes, dim=0)
+ if quizzes.size(0) > input.size(0) // 2:
+ i = torch.randperm(input.size(0))[: input.size(0) // 2]
+ quizzes = quizzes[i]
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
+ i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
+ input = input[i]
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
+ self.nb_batch_samples_world = input.size(0)
+ self.nb_batch_samples_quizzes = quizzes.size(0)
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
+ input = torch.cat([input, quizzes], dim=0)
else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
+ self.nb_batch_samples_world = input.size(0)
+ self.nb_batch_samples_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
- seq_logproba=None,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger)
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
##############################
- input = self.test_w_quizzes[:96]
+ input = self.test_input[:96]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
- seq_logproba=None,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
- self.save_quizzes(
+ self.save_image(
result[:72],
result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
+ f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
logger,
)
return main_test_accuracy
- def renew_w_quizzes(self, nb, for_train=True):
- input = self.train_w_quizzes if for_train else self.test_w_quizzes
+ def renew_samples(self, nb, for_train=True):
+ input = self.train_input if for_train else self.test_input
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
- def store_c_quizzes(self, new_c_quizzes, for_train=True):
+ def store_new_quizzes(self, new_quizzes, for_train=True):
if for_train:
- self.train_c_quizzes.append(new_c_quizzes)
+ self.train_quizzes.append(new_quizzes)
else:
- self.test_c_quizzes.append(new_c_quizzes)
+ self.test_quizzes.append(new_quizzes)
- def create_c_quizzes(
+ def create_new_quizzes(
self,
n_epoch,
result_dir,
nb,
model,
other_models,
- min_ave_seq_logproba=None,
+ desired_average_logits=None,
):
###############################################################
# Generate quizzes with model
- c_quizzes = torch.empty(
+ quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
- seq_logproba = torch.empty(nb, device=self.device)
+ ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ summed_logits = torch.empty(nb, device=self.device)
temperature = 1
d_temperature = 1
while True:
- seq_logproba[...] = 0
+ summed_logits[...] = 0
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
- input=c_quizzes,
+ input=quizzes,
ar_mask=ar_mask,
- seq_logproba=seq_logproba,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
- progress_bar_desc="sampling c_quizzes",
+ progress_bar_desc="creating quizzes",
device=self.device,
)
- ave_seq_logproba = seq_logproba.mean()
+ average_logits = summed_logits.mean()
- logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
+ logger(f"{average_logits=} {desired_average_logits=}")
- if min_ave_seq_logproba is None:
+ if desired_average_logits is None:
break
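- # Oh man that's ugly
+ # Ugly but simple: keep nudging the temperature and resampling until
+ # average_logits lies between desired_average_logits * 1.1 and
+ # desired_average_logits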
- if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+ if average_logits < desired_average_logits * 1.1:
if d_temperature > 0:
d_temperature *= -0.5
temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba:
+ elif average_logits > desired_average_logits:
if d_temperature < 0:
d_temperature *= -0.5
temperature += d_temperature
else:
break
- logger(f"chaging temperature to {temperature}")
+ logger(f"changing temperature to {temperature}")
###############################################################
# Create the reverse quizzes
l = self.height * self.width
- direction = c_quizzes[:, l : l + 1]
+ direction = quizzes[:, l : l + 1]
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
- reverse_c_quizzes = torch.cat(
- [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
+ reverse_quizzes = torch.cat(
+ [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
)
- ar_mask = self.make_ar_mask(c_quizzes)
+ ar_mask = self.make_ar_mask(quizzes)
###############################################################
# Check how many of the other models can solve them in both
nb_correct = []
for m in other_models:
- result = c_quizzes.clone()
+ result = quizzes.clone()
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
- seq_logproba=None,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving c_quizzes",
+ progress_bar_desc="solving quizzes",
device=self.device,
)
- correct = (c_quizzes == result).long().min(dim=-1).values
+ correct = (quizzes == result).long().min(dim=-1).values
- reverse_result = reverse_c_quizzes.clone()
+ reverse_result = reverse_quizzes.clone()
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
- seq_logproba=None,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving reversed c_quizzes",
+ progress_bar_desc="solving reversed quizzes",
device=self.device,
)
reverse_correct = (
- (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
+ (reverse_quizzes == reverse_result).long().min(dim=-1).values
)
nb_correct.append((correct * reverse_correct)[None, :])
- nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
+ nb_correct = torch.cat(nb_correct, dim=0)
+
+ # filename = os.path.join(result_dir, f"correct_{n_epoch:04d}.dat")
+ # with open(filename, "w") as f:
+ # for k in nb_correct:
+ # f.write(f"{k}\n")
- return c_quizzes, nb_correct, seq_logproba.mean()
+ return quizzes, nb_correct.sum(dim=0), summed_logits.mean()