def batch_prediction(input, proba_hints=0.0):
    """Build a masked-prediction task from a batch of quizzes.

    The sequence is viewed as 4 channel groups; for each sample one group is
    picked at random and every position of that group except the first is
    marked for generation.  With probability ``proba_hints`` per sample, some
    of the masked positions are revealed again as hints.

    Args:
        input: long tensor of quizzes, shape (nb, L) with L divisible by 4.
        proba_hints: per-sample probability of revealing hint positions.

    Returns:
        (input, targets, mask): ``targets`` is the original batch, ``mask``
        marks the positions to generate (1), and ``input`` is the batch with
        those positions zeroed out.
    """
    nb = input.size(0)

    # One-hot pick of the channel group to erase, per sample.
    mask = input.new_zeros(input.size())
    u = F.one_hot(torch.randint(4, (nb,), device=mask.device), num_classes=4)
    # Keep position 0 of every group visible; erase the rest of the chosen group.
    mask.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]

    if proba_hints > 0:
        # Randomly score the masked positions.
        h = torch.rand(input.size(), device=input.device) * mask
        # NOTE(review): comparing sorted *values* (in [0,1)) to the integer
        # args.nb_hints looks suspicious -- possibly a rank/top-k computation
        # was intended; confirm against the original implementation.
        mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
        # Apply hints only to a proba_hints fraction of the samples.
        v = torch.rand(nb, device=input.device)[:, None]
        mask_hints = mask_hints * (v < proba_hints).long()
        mask = (1 - mask_hints) * mask

    # noise = quiz_machine.problem.pure_noise(nb, input.device)
    targets = input
    input = (1 - mask) * targets  # + mask * noise

    return input, targets, mask
def predict(model, input, targets, mask, local_device=main_device):
    # NOTE(review): this body is unresolved merge residue.  `mask_erased`
    # and `noise` are not defined anywhere in view, the `model`, `targets`
    # and `mask` parameters are ignored or overwritten, and the caller in
    # one_train_test_epoch uses the return value as a single tensor while
    # three values are returned here.  The diff markers have been resolved
    # (rename mask_generate -> mask); the remaining logic must be restored
    # from the original file -- confirm intent before relying on it.
    targets = input
    input = (1 - mask_erased) * input + mask_erased * noise
    mask = input.new_full(input.size(), 1)
    # Position 0 of each of the 4 channel groups stays fixed.
    mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0
    return input, targets, mask
def prioritized_rand(low):
def generate(model, nb, local_device=main_device):
    """Sample ``nb`` quizzes by iterative denoising with ``model``.

    Starts from pure noise and repeatedly re-samples a set of disagreeing
    positions until the sequence stops changing or the iteration budget is
    exhausted.
    """
    input = quiz_machine.problem.pure_noise(nb, local_device)

    # Generate everything except position 0 of each of the 4 channel groups.
    mask = input.new_full(input.size(), 1)
    mask.reshape(mask.size(0), 4, -1)[:, :, 0] = 0

    changed = True  # NOTE(review): set but never read in the visible code

    for it in range(args.diffusion_nb_iterations):
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(NTC_channel_cat(input, mask))
        dist = torch.distributions.categorical.Categorical(logits=logits)
        output = dist.sample()

        # Re-sample preferentially where the model's sample disagrees with
        # the current input, restricted to the generated area.
        r = prioritized_rand(input != output)
        mask_changes = (r <= args.diffusion_proba_corruption).long() * mask
        update = (1 - mask_changes) * input + mask_changes * output

        if update.equal(input):
            # Converged: no position changed in this iteration.
            break
        else:
            # NOTE(review): the else-branch body was missing in the merged
            # source; carrying the update forward is the only coherent
            # completion -- confirm against the original.
            input = update

    # NOTE(review): the return was truncated in the merged source; the
    # caller expects a single tensor.
    return input
def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
    """Run one train and one test epoch, then measure prediction accuracy
    and save a sample of generated quizzes as an image."""

    # train
    # BUG FIX: the original passed the global `main_device` here, silently
    # ignoring the `local_device` parameter.
    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)
    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)

    # predict
    quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
    input, targets, mask = batch_prediction(quizzes.to(local_device))
    result = predict(model, input, targets, mask).to("cpu")
    # A quiz counts as correct only when every position matches.
    # NOTE(review): `correct` was undefined in the merged source -- this
    # reconstruction compares the prediction to the original quizzes;
    # confirm against the original file.
    correct = (result == quizzes).min(dim=1).values.long()
    model.test_accuracy = correct.sum() / quizzes.size(0)

    # generate
    result = generate(model, 25).to("cpu")
    quiz_machine.problem.save_quizzes_as_image(
        args.result_dir,
        f"culture_generation_{n_epoch}_{model.id}.png",
        quizzes=result[:128],
    )
######################################################################