import ffutils
import mygpt, tasks
+# world quizzes vs. culture quizzes
+
+######################################################################
+
+accuracy_to_make_c_quizzes = 0.975
+nb_new_c_quizzes_for_train = 1000
+nb_new_c_quizzes_for_test = 100
+
######################################################################
if torch.cuda.is_available():
######################################################################
+if args.dirty_debug:
+ accuracy_to_make_c_quizzes = 0.0
+ nb_new_c_quizzes_for_train = 100
+ nb_new_c_quizzes_for_test = 10
+
+######################################################################
+
default_args = {
"model": "37M",
"batch_size": 100,
######################################################################
-def create_quizzes(
+def create_c_quizzes(
model,
other_models,
task,
):
kept = []
- sum_logits, sum_nb_quizzes = 0, 0
+ sum_logits, sum_nb_c_quizzes = 0, 0
while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
nb_to_generate = 4 * (nb_for_train + nb_for_test)
- new_quizzes, nb_correct, average_logits = task.create_new_quizzes(
+ new_c_quizzes, nb_correct, average_logits = task.create_c_quizzes(
n_epoch=n_epoch,
result_dir=args.result_dir,
logger=log_string,
desired_average_logits=desired_average_logits,
)
- sum_logits += new_quizzes.size(0) * average_logits
- sum_nb_quizzes += new_quizzes.size(0)
+ sum_logits += new_c_quizzes.size(0) * average_logits
+ sum_nb_c_quizzes += new_c_quizzes.size(0)
- to_keep = new_quizzes[nb_correct == len(other_models) - 1]
+ to_keep = new_c_quizzes[nb_correct == len(other_models) - 1]
if args.dirty_debug:
- to_keep = new_quizzes
+ to_keep = new_c_quizzes
log_string(
- f"keep {to_keep.size(0)}/{new_quizzes.size(0)} quizzes ({to_keep.size(0)*100/new_quizzes.size(0):.02f}%)"
+ f"keep {to_keep.size(0)}/{new_c_quizzes.size(0)} c_quizzes ({to_keep.size(0)*100/new_c_quizzes.size(0):.02f}%)"
)
kept.append(to_keep)
- new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+ new_c_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
- task.store_new_quizzes(new_quizzes[:nb_for_train], for_train=True)
- task.store_new_quizzes(new_quizzes[nb_for_train:], for_train=False)
+ task.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
+ task.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
- task.save_image(
- new_quizzes[:72],
+ task.save_quizzes(
+ new_c_quizzes[:72],
args.result_dir,
- f"world_quiz_{n_epoch:04d}_{model.id:02d}.png",
+ f"culture_c_quiz_{n_epoch:04d}_{model.id:02d}",
log_string,
)
- return sum_logits / sum_nb_quizzes
+ return sum_logits / sum_nb_c_quizzes
######################################################################
######################################################################
-accuracy_to_make_quizzes = 0.975
-nb_new_quizzes_for_train = 1000
-nb_new_quizzes_for_test = 100
-
-if args.dirty_debug:
- accuracy_to_make_quizzes = 0.0
- nb_new_quizzes_for_train = 100
- nb_new_quizzes_for_test = 10
-
desired_average_logits = None
for n_epoch in range(args.nb_epochs):
# improve it
one_epoch(model, task)
- task.renew_samples(args.nb_train_samples // args.nb_gpts)
+ task.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
log_string(
- f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+ f"train_set_composition w_quizzes {task.nb_batch_w_quizzes} c_quizzes {task.nb_batch_c_quizzes}"
)
# test it
run_tests(model, task, deterministic_synthesis=False)
log_string(
- f"test_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+ f"test_set_composition w_quizzes {task.nb_batch_w_quizzes} c_quizzes {task.nb_batch_c_quizzes}"
)
- if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes:
+ if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_c_quizzes:
other_models = models.copy()
other_models.remove(model)
- average_logits = create_quizzes(
+ average_logits = create_c_quizzes(
model,
other_models,
task,
- nb_for_train=nb_new_quizzes_for_train,
- nb_for_test=nb_new_quizzes_for_test,
+ nb_for_train=nb_new_c_quizzes_for_train,
+ nb_for_test=nb_new_c_quizzes_for_test,
desired_average_logits=desired_average_logits,
)
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
self.height = 6
self.width = 8
- self.train_input = world.generate_seq(
+ self.train_w_quizzes = world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_input = world.generate_seq(
+ self.test_w_quizzes = world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ self.nb_codes = max(self.train_w_quizzes.max(), self.test_w_quizzes.max()) + 1
- self.train_quizzes = []
- self.test_quizzes = []
+ self.train_c_quizzes = []
+ self.test_c_quizzes = []
if result_dir is not None:
- self.save_image(
- self.train_input[:72], result_dir, f"world_train.png", logger
+ self.save_quizzes(
+ self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- input = self.train_input
- quizzes = self.train_quizzes
+ w_quizzes = self.train_w_quizzes
+ c_quizzes = self.train_c_quizzes
else:
- input = self.test_input
- quizzes = self.test_quizzes
+ w_quizzes = self.test_w_quizzes
+ c_quizzes = self.test_c_quizzes
- if len(quizzes) > 0:
- quizzes = torch.cat(quizzes, dim=0)
- if quizzes.size(0) > input.size(0) // 2:
- i = torch.randperm(input.size(0))[: input.size(0) // 2]
- quizzes = quizzes[i]
+ if len(c_quizzes) > 0:
+ c_quizzes = torch.cat(c_quizzes, dim=0)
+ if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+ i = torch.randperm(w_quizzes.size(0))[: w_quizzes.size(0) // 2]
+ c_quizzes = c_quizzes[i]
- i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
- input = input[i]
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = quizzes.size(0)
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = c_quizzes.size(0)
- input = torch.cat([input, quizzes], dim=0)
+ input = torch.cat([w_quizzes, c_quizzes], dim=0)
else:
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = 0
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_w_quizzes)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_w_quizzes, logger)
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
##############################
- input = self.test_input[:96]
+ input = self.test_w_quizzes[:96]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
device=self.device,
)
- self.save_image(
+ self.save_quizzes(
result[:72],
result_dir,
- f"world_prediction_{n_epoch:04d}_{model.id:02d}.png",
+ f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
logger,
)
return main_test_accuracy
- def renew_samples(self, nb, for_train=True):
- input = self.train_input if for_train else self.test_input
+ def renew_w_quizzes(self, nb, for_train=True):
+ input = self.train_w_quizzes if for_train else self.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
- def store_new_quizzes(self, new_quizzes, for_train=True):
+ def store_c_quizzes(self, new_c_quizzes, for_train=True):
if for_train:
- self.train_quizzes.append(new_quizzes)
+ self.train_c_quizzes.append(new_c_quizzes)
else:
- self.test_quizzes.append(new_quizzes)
+ self.test_c_quizzes.append(new_c_quizzes)
- def create_new_quizzes(
+ def create_c_quizzes(
self,
n_epoch,
result_dir,
###############################################################
# Generate quizzes with model
- quizzes = torch.empty(
+ c_quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
summed_logits = torch.empty(nb, device=self.device)
temperature = 1
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
- input=quizzes,
+ input=c_quizzes,
ar_mask=ar_mask,
summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
- progress_bar_desc="creating quizzes",
+ progress_bar_desc="sampling c_quizzes",
device=self.device,
)
# Create the reverse quizzes
l = self.height * self.width
- direction = quizzes[:, l : l + 1]
+ direction = c_quizzes[:, l : l + 1]
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
- reverse_quizzes = torch.cat(
- [quizzes[:, l + 1 :], direction, quizzes[:, :l]], dim=1
+ reverse_c_quizzes = torch.cat(
+ [c_quizzes[:, l + 1 :], direction, c_quizzes[:, :l]], dim=1
)
- ar_mask = self.make_ar_mask(quizzes)
+ ar_mask = self.make_ar_mask(c_quizzes)
###############################################################
# Check how many of the other models can solve them in both
nb_correct = []
for m in other_models:
- result = quizzes.clone()
+ result = c_quizzes.clone()
masked_inplace_autoregression(
model=m,
summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving quizzes",
+ progress_bar_desc="solving c_quizzes",
device=self.device,
)
- correct = (quizzes == result).long().min(dim=-1).values
+ correct = (c_quizzes == result).long().min(dim=-1).values
- reverse_result = reverse_quizzes.clone()
+ reverse_result = reverse_c_quizzes.clone()
masked_inplace_autoregression(
model=m,
summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving reversed quizzes",
+ progress_bar_desc="solving reversed c_quizzes",
device=self.device,
)
reverse_correct = (
- (reverse_quizzes == reverse_result).long().min(dim=-1).values
+ (reverse_c_quizzes == reverse_result).long().min(dim=-1).values
)
nb_correct.append((correct * reverse_correct)[None, :])
- nb_correct = torch.cat(nb_correct, dim=0)
+ nb_correct = torch.cat(nb_correct, dim=0).sum(dim=0)
# filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
# with open(filename, "w") as f:
# for k in nb_correct:
# f.write(f"{k}\n")
- return quizzes, nb_correct.sum(dim=0), summed_logits.mean()
+ return c_quizzes, nb_correct, summed_logits.mean()