batch_size,
input,
ar_mask,
+ seq_logproba,
+ temperature,
deterministic_synthesis,
forbidden_tokens=None,
logit_biases=None,
):
assert input.size() == ar_mask.size()
- batches = zip(input.split(batch_size), ar_mask.split(batch_size))
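+ # also split the per-sequence log-proba accumulator; the splits are views,
+ # so each batch writes its log-probas back into the shared tensor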
+ batches = zip(
+ input.split(batch_size),
+ ar_mask.split(batch_size),
+ seq_logproba.split(batch_size),
+ )
if progress_bar_desc is not None:
batches = tqdm.tqdm(
t = model.training
model.eval()
- for input, ar_mask in batches:
+ for input, ar_mask, seq_logproba in batches:
model.masked_inplace_autoregression(
- input,
- ar_mask,
- deterministic_synthesis,
- forbidden_tokens,
- logit_biases,
+ input=input,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=deterministic_synthesis,
+ forbidden_tokens=forbidden_tokens,
+ forced_biases=logit_biases,
)
model.train(t)
import world
-class World(Task):
+class QuizzMachine(Task):
def save_image(self, input, result_dir, filename, logger):
- img = world.sample2img(input.to("cpu"), self.height, self.width)
+ img = world.seq2img(input.to("cpu"), self.height, self.width)
image_name = os.path.join(result_dir, filename)
- torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
+ torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
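+ # the mask is 1 on the positions to generate (past the sequence midpoint), 0 on the prompt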
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
self.height = 6
self.width = 8
- self.train_input = world.generate(
+ self.train_w_quizzes = world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_input = world.generate(
+ self.test_w_quizzes = world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
- self.train_quizzes = []
- self.test_quizzes = []
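+ # w_quizzes come from world.generate_seq; c_quizzes are created by the models themselves (see create_c_quizzes)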
+ self.train_c_quizzes = []
+ self.test_c_quizzes = []
if result_dir is not None:
- self.save_image(
- self.train_input[:96], result_dir, f"world_train.png", logger
+ self.save_quizzes(
+ self.train_w_quizzes[:72], result_dir, "culture_w_quizzes", logger
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- input = self.train_input
- quizzes = self.train_quizzes
+ w_quizzes = self.train_w_quizzes
+ c_quizzes = self.train_c_quizzes
else:
- input = self.test_input
- quizzes = self.test_quizzes
+ w_quizzes = self.test_w_quizzes
+ c_quizzes = self.test_c_quizzes
- if len(quizzes) > 0:
- quizzes = torch.cat(quizzes, dim=0)
- if quizzes.size(0) > input.size(0) // 2:
- i = torch.randperm(input.size(0))[: input.size(0) // 2]
- quizzes = quizzes[i]
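+ # keep at most half of the pool as c_quizzes and fill the rest with w_quizzes,
+ # so the total number of samples stays equal to the original w_quizzes pool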
+ if len(c_quizzes) > 0:
+ c_quizzes = torch.cat(c_quizzes, dim=0)
+ if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+ i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2]
+ c_quizzes = c_quizzes[i]
- i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
- input = input[i]
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = quizzes.size(0)
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = c_quizzes.size(0)
- input = torch.cat([input, quizzes], dim=0)
+ input = torch.cat([w_quizzes, c_quizzes], dim=0)
else:
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = 0
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
+
+ # Shuffle
+ input = input[torch.randperm(input.size(0))]
if desc is None:
desc = f"epoch-{split}"
input = input[:nmax]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
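+ # per-sequence log-proba accumulator required by the new interface (not read further here)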
+ seq_logproba = torch.empty(input.size(0), device=self.device)
masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
+ model=model,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=1.0,
+ deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger)
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
##############################
- input = self.test_input[:96]
+ input = self.test_w_quizzes[:96]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
+ seq_logproba = torch.empty(input.size(0), device=self.device)
masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
+ model=model,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=1.0,
+ deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
- self.save_image(
- result[:96],
+ self.save_quizzes(
+ result[:72],
result_dir,
- f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
+ f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
logger,
)
return main_test_accuracy
- def store_new_quizzes(self, new_quizzes, for_train=True):
+ def renew_w_quizzes(self, nb, for_train=True):
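+ # drop the nb oldest w_quizzes and replace them with freshly generated ones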
+ input = self.train_w_quizzes if for_train else self.test_w_quizzes
+ nb = min(nb, input.size(0))
+ input[:-nb] = input[nb:].clone()
+ input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
+ self.device
+ )
+
+ def store_c_quizzes(self, new_c_quizzes, for_train=True):
if for_train:
- self.train_quizzes.append(new_quizzes)
+ self.train_c_quizzes.append(new_c_quizzes)
else:
- self.test_quizzes.append(new_quizzes)
+ self.test_c_quizzes.append(new_c_quizzes)
- def create_new_quizzes(
+ def create_c_quizzes(
self,
n_epoch,
result_dir,
nb,
model,
other_models,
+ logger,
+ min_ave_seq_logproba,
):
###############################################################
# Generate quizzes with model
- quizzes = torch.empty(
+ c_quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
- masked_inplace_autoregression(
- model,
- self.batch_size,
- quizzes,
- ar_mask,
- deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
- device=self.device,
- )
+ ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
+ seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
+
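+ # sample at temperature 1, then adjust the temperature until the average
+ # sequence log-proba reaches the min_ave_seq_logproba target (if one is given)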
+ temperature = 1
+ d_temperature = 1
+
+ while True:
+ seq_logproba[...] = 0
+
+ masked_inplace_autoregression(
+ model=model,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=False,
+ progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
+
+ ave_seq_logproba = seq_logproba.mean()
+
+ logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
+
+ if min_ave_seq_logproba is None:
+ break
+
+ # Oh man that's ugly
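+ # damped search: step the temperature until the average log-proba falls in the
+ # band [1.1 * min, min]; each overshoot flips the step's sign and shrinks it by 3x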
+ if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+ if d_temperature > 0:
+ d_temperature *= -1 / 3
+ temperature += d_temperature
+ elif ave_seq_logproba > min_ave_seq_logproba:
+ if d_temperature < 0:
+ d_temperature *= -1 / 3
+ temperature += d_temperature
+ else:
+ break
+
+ logger(f"chaging temperature to {temperature}")
###############################################################
# Create the reverse quizzes
l = self.height * self.width
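+ # swap the two grids and flip the direction token, so each quiz must also be solved the other way around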
- direction = quizzes[:, l : l + 1]
+ direction = c_quizzes[:, l : l + 1]
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
- reverse_quizzes = torch.cat(
- [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+ reverse_c_quizzes = torch.cat(
+ [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
)
- ar_mask = self.make_ar_mask(quizzes)
+ ar_mask = self.make_ar_mask(c_quizzes)
+ seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
###############################################################
# Check how many of the other models can solve them in both
# directions
- nb_correct = 0
+ nb_correct = []
for m in other_models:
- result = quizzes.clone()
+ result = c_quizzes.clone()
masked_inplace_autoregression(
- m,
- self.batch_size,
- result,
- ar_mask,
+ model=m,
+ batch_size=self.batch_size,
+ input=result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving quizzes",
+ progress_bar_desc="solving c_quizzes",
device=self.device,
)
- correct = (quizzes == result).long().min(dim=-1).values
+ correct = (c_quizzes == result).long().min(dim=-1).values
- reverse_result = reverse_quizzes.clone()
+ reverse_result = reverse_c_quizzes.clone()
masked_inplace_autoregression(
- m,
- self.batch_size,
- reverse_result,
- ar_mask,
+ model=m,
+ batch_size=self.batch_size,
+ input=reverse_result,
+ ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
+ temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving reversed quizzes",
+ progress_bar_desc="solving reversed c_quizzes",
device=self.device,
)
reverse_correct = (
- (reverse_quizzes == reverse_result).long().min(dim=-1).values
+ (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
)
- nb_correct += correct * reverse_correct
+ nb_correct.append((correct * reverse_correct)[None, :])
+
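+ # one row per validating model; summing over models gives, for each quiz,
+ # how many models solved it in both directions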
+ nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
- return quizzes, nb_correct
+ return c_quizzes, nb_correct, seq_logproba.mean()