batch_size,
input,
ar_mask,
+ temperature,
deterministic_synthesis,
forbidden_tokens=None,
logit_biases=None,
total=(input.size(0) + batch_size - 1) // batch_size,
)
+ sum_logits = 0
+
with torch.autograd.no_grad():
t = model.training
model.eval()
for input, ar_mask in batches:
- model.masked_inplace_autoregression(
- input,
- ar_mask,
- deterministic_synthesis,
- forbidden_tokens,
- logit_biases,
+ sum_logits += model.masked_inplace_autoregression(
+ input=input,
+ ar_mask=ar_mask,
+ temperature=temperature,
+ deterministic_synthesis=deterministic_synthesis,
+ forbidden_tokens=forbidden_tokens,
+ forced_biases=logit_biases,
)
model.train(t)
+ return sum_logits
+
######################################################################
class World(Task):
def save_image(self, input, result_dir, filename, logger):
- img = world.sample2img(input.to("cpu"), self.height, self.width)
+ img = world.seq2img(input.to("cpu"), self.height, self.width)
image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
def make_ar_mask(self, input):
self.height = 6
self.width = 8
- self.train_input = world.generate(
+ self.train_input = world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_input = world.generate(
+ self.test_input = world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
if result_dir is not None:
self.save_image(
- self.train_input[:96], result_dir, f"world_train.png", logger
+ self.train_input[:72], result_dir, f"world_train.png", logger
)
def batches(self, split="train", desc=None):
self.nb_batch_samples_world = input.size(0)
self.nb_batch_samples_quizzes = 0
+ # Shuffle
+ input = input[torch.randperm(input.size(0))]
+
if desc is None:
desc = f"epoch-{split}"
for batch in tqdm.tqdm(
result = input.clone() * (1 - ar_mask)
masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
+ model=model,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ temperature=1.0,
+ deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
result = input.clone() * (1 - ar_mask)
masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
+ model=model,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ temperature=1.0,
+ deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
self.save_image(
- result[:96],
+ result[:72],
result_dir,
f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
logger,
return main_test_accuracy
+ def renew_samples(self, nb, for_train=True):
+ input = self.train_input if for_train else self.test_input
+ nb = min(nb, input.size(0))
+ input[:-nb] = input[nb:].clone()
+ input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+ self.device
+ )
+
def store_new_quizzes(self, new_quizzes, for_train=True):
if for_train:
self.train_quizzes.append(new_quizzes)
nb,
model,
other_models,
+ desired_average_logits=None,
):
###############################################################
# Generate quizzes with model
quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
+
ar_mask = torch.full(quizzes.size(), 1, device=self.device)
- masked_inplace_autoregression(
- model,
- self.batch_size,
- quizzes,
- ar_mask,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )
+ temperature = 1
+ d_temperature = 1
+
+ while True:
+ sum_logits = masked_inplace_autoregression(
+ model=model,
+ batch_size=self.batch_size,
+ input=quizzes,
+ ar_mask=ar_mask,
+ temperature=temperature,
+ deterministic_synthesis=False,
+ progress_bar_desc="creating quizzes",
+ device=self.device,
+ )
+
+ average_logits = sum_logits / quizzes.size(0)
+
+ logger(f"{average_logits=} {desired_average_logits=}")
+
+ if desired_average_logits is None:
+ break
+
+ # Crude bisection on the sampling temperature: whenever the average
+ # logits overshoot the target, reverse the direction of the step and
+ # halve its magnitude.
+ if average_logits > desired_average_logits:
+ if d_temperature < 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+ else:
+ if d_temperature > 0:
+ d_temperature *= -0.5
+ temperature += d_temperature
+
+ logger(f"chaging temperature to {temperature}")
###############################################################
# Create the reverse quizzes
# Check how many of the other models can solve them in both
# directions
- nb_correct = 0
+ nb_correct = []
for m in other_models:
result = quizzes.clone()
masked_inplace_autoregression(
- m,
- self.batch_size,
- result,
- ar_mask,
+ model=m,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving quizzes",
device=self.device,
reverse_result = reverse_quizzes.clone()
masked_inplace_autoregression(
- m,
- self.batch_size,
- reverse_result,
- ar_mask,
+ model=m,
+ batch_size=self.batch_size,
+ input=reverse_result,
+ ar_mask=ar_mask,
+ temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving reversed quizzes",
device=self.device,
(reverse_quizzes == reverse_result).long().min(dim=-1).values
)
- nb_correct += correct * reverse_correct
+ nb_correct.append((correct * reverse_correct)[None, :])
+
+ nb_correct = torch.cat(nb_correct, dim=0)
+
+ # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
+ # with open(filename, "w") as f:
+ # for k in nb_correct:
+ # f.write(f"{k}\n")
- return quizzes, nb_correct
+ return quizzes, nb_correct.sum(dim=0), sum_logits