batch_size,
input,
ar_mask,
+ seq_logproba,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
):
assert input.size() == ar_mask.size()
- batches = zip(input.split(batch_size), ar_mask.split(batch_size))
+ batches = zip(
+ input.split(batch_size),
+ ar_mask.split(batch_size),
+ seq_logproba.split(batch_size),
+ )
if progress_bar_desc is not None:
batches = tqdm.tqdm(
total=(input.size(0) + batch_size - 1) // batch_size,
)
- sum_logits = 0
-
with torch.autograd.no_grad():
t = model.training
model.eval()
- for input, ar_mask in batches:
- sum_logits += model.masked_inplace_autoregression(
+ for input, ar_mask, seq_logproba in batches:
+ model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
model.train(t)
- return sum_logits
-
######################################################################
import world
-class World(Task):
+class QuizzMachine(Task):
def save_image(self, input, result_dir, filename, logger):
img = world.seq2img(input.to("cpu"), self.height, self.width)
image_name = os.path.join(result_dir, filename)
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
self.height = 6
self.width = 8
- self.train_input = world.generate_seq(
+ self.train_w_quizzes = world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_input = world.generate_seq(
+ self.test_w_quizzes = world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
- self.train_quizzes = []
- self.test_quizzes = []
+ self.train_c_quizzes = []
+ self.test_c_quizzes = []
if result_dir is not None:
- self.save_image(
- self.train_input[:72], result_dir, f"world_train.png", logger
+ self.save_quizzes(
+ self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- input = self.train_input
- quizzes = self.train_quizzes
+ w_quizzes = self.train_w_quizzes
+ c_quizzes = self.train_c_quizzes
else:
- input = self.test_input
- quizzes = self.test_quizzes
+ w_quizzes = self.test_w_quizzes
+ c_quizzes = self.test_c_quizzes
- if len(quizzes) > 0:
- quizzes = torch.cat(quizzes, dim=0)
- if quizzes.size(0) > input.size(0) // 2:
- i = torch.randperm(input.size(0))[: input.size(0) // 2]
- quizzes = quizzes[i]
+ if len(c_quizzes) > 0:
+ c_quizzes = torch.cat(c_quizzes, dim=0)
+ if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+ i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2]
+ c_quizzes = c_quizzes[i]
- i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
- input = input[i]
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = quizzes.size(0)
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = c_quizzes.size(0)
- input = torch.cat([input, quizzes], dim=0)
+ input = torch.cat([w_quizzes, c_quizzes], dim=0)
else:
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = 0
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
input = input[:nmax]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
+ seq_logproba = torch.empty(input.size(0), device=self.device)
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger)
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
##############################
- input = self.test_input[:96]
+ input = self.test_w_quizzes[:96]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
+ seq_logproba = torch.empty(input.size(0), device=self.device)
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
- self.save_image(
+ self.save_quizzes(
result[:72],
result_dir,
- f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
+ f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
logger,
)
return main_test_accuracy
- def renew_samples(self, nb, for_train=True):
- input = self.train_input if for_train else self.test_input
+ def renew_w_quizzes(self, nb, for_train=True):
+ input = self.train_w_quizzes if for_train else self.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
- def store_new_quizzes(self, new_quizzes, for_train=True):
+ def store_c_quizzes(self, new_c_quizzes, for_train=True):
if for_train:
- self.train_quizzes.append(new_quizzes)
+ self.train_c_quizzes.append(new_c_quizzes)
else:
- self.test_quizzes.append(new_quizzes)
+ self.test_c_quizzes.append(new_c_quizzes)
- def create_new_quizzes(
+ def create_c_quizzes(
self,
n_epoch,
result_dir,
nb,
model,
other_models,
- desired_average_logits=None,
+ min_ave_seq_logproba,
):
###############################################################
# Generate quizzes with model
- quizzes = torch.empty(
+ c_quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
+ seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
temperature = 1
d_temperature = 1
while True:
- sum_logits = masked_inplace_autoregression(
+ seq_logproba[...] = 0
+
+ masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
- input=quizzes,
+ input=c_quizzes,
ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
+ progress_bar_desc="sampling c_quizzes",
device=self.device,
)
- average_logits = sum_logits / quizzes.size(0)
+ ave_seq_logproba = seq_logproba.mean()
- logger(f"{average_logits=} {desired_average_logits=}")
+ logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
- if desired_average_logits is None:
+ if min_ave_seq_logproba is None:
break
# Oh man that's ugly
- if average_logits > desired_average_logits:
+ if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+ if d_temperature > 0:
+ d_temperature *= -1 / 3
+ temperature += d_temperature
+ elif ave_seq_logproba > min_ave_seq_logproba:
if d_temperature < 0:
- d_temperature *= -0.5
+ d_temperature *= -1 / 3
temperature += d_temperature
else:
- if d_temperature > 0:
- d_temperature *= -0.5
- temperature += d_temperature
+ break
+
logger(f"chaging temperature to {temperature}")
###############################################################
# Create the reverse quizzes
l = self.height * self.width
- direction = quizzes[:, l : l + 1]
+ direction = c_quizzes[:, l : l + 1]
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
- reverse_quizzes = torch.cat(
- [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+ reverse_c_quizzes = torch.cat(
+ [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
)
- ar_mask = self.make_ar_mask(quizzes)
+ ar_mask = self.make_ar_mask(c_quizzes)
+ seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
###############################################################
# Check how many of the other models can solve them in both
nb_correct = []
for m in other_models:
- result = quizzes.clone()
+ result = c_quizzes.clone()
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving quizzes",
+ progress_bar_desc="solving c_quizzes",
device=self.device,
)
- correct = (quizzes == result).long().min(dim=-1).values
+ correct = (c_quizzes == result).long().min(dim=-1).values
- reverse_result = reverse_quizzes.clone()
+ reverse_result = reverse_c_quizzes.clone()
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
+ seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving reversed quizzes",
+ progress_bar_desc="solving reversed c_quizzes",
device=self.device,
)
reverse_correct = (
- (reverse_quizzes == reverse_result).long().min(dim=-1).values
+ (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
)
nb_correct.append((correct * reverse_correct)[None, :])
- nb_correct = torch.cat(nb_correct, dim=0)
-
- # filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
- # with open(filename, "w") as f:
- # for k in nb_correct:
- # f.write(f"{k}\n")
+ nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
- return quizzes, nb_correct.sum(dim=0), sum_logits
+ return c_quizzes, nb_correct, seq_logproba.mean()