parser.add_argument("--nb_c_quizzes", type=int, default=10000)
-parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
-
-parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
-
parser.add_argument("--c_quiz_multiplier", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--nb_models", type=int, default=5)
-parser.add_argument("--nb_diffusion_iterations", type=int, default=25)
+parser.add_argument("--diffusion_nb_iterations", type=int, default=25)
-parser.add_argument("--proba_diffusion_corruption", type=float, default=0.05)
+parser.add_argument("--diffusion_proba_corruption", type=float, default=0.05)
parser.add_argument("--min_succeed_to_validate", type=int, default=2)
diffuser = diffusion.Diffuser(
- mu_T_sampler, args.nb_diffusion_iterations, args.proba_diffusion_corruption
+ mu_T_sampler, args.diffusion_nb_iterations, args.diffusion_proba_corruption
)
######################################################################
)
def NTC_channel_cat(*x):
    """Combine same-shape (N, T) tensors into one (N, T, C) tensor.

    Every argument is broadcast (expand_as) to the shape of the first
    one, then the arguments become the channel dimension C.
    """
    reference = x[0]
    channels = [a.expand_as(reference) for a in x]
    return torch.stack(channels, dim=2)
+
+
def NTC_masked_cross_entropy(output, targets, mask):
    """Cross-entropy on (N, T, C) logits vs (N, T) targets, weighted by mask.

    The per-token losses are multiplied by `mask` and averaged over ALL
    positions (masked-out positions contribute zeros to the mean).
    """
    logits_nct = output.permute(0, 2, 1)
    per_token_loss = F.cross_entropy(logits_nct, targets, reduction="none")
    return torch.mean(mask * per_token_loss)
######################################################################
-def one_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
+def one_epoch_(model, n_epoch, c_quizzes, local_device=main_device):
model.train().to(local_device)
optimizer_to(model.optimizer, local_device)
)
+######################################################################
+
+
def batch_prediction(input, proba_hints=0.0):
    """Build a masked-prediction triple (input, targets, mask_generate).

    Each row of `input` is treated as four equal-length segments
    (via view(nb, 4, -1)); one segment per sample is picked uniformly at
    random and marked for generation, except the segment's first token
    (presumably a structural marker token — TODO confirm).  With
    probability `proba_hints` per sample, part of the to-generate
    positions may be revealed again as hints.  The returned `input` has
    the to-generate positions zeroed, `targets` is the untouched
    original, and `mask_generate` is 1 where the model must predict.
    """
    nb = input.size(0)
    mask_generate = input.new_zeros(input.size())
    # one-hot choice of which of the 4 segments each sample must generate
    u = F.one_hot(torch.randint(4, (nb,), device=mask_generate.device), num_classes=4)
    # mark every position of the chosen segment except its first token
    mask_generate.view(nb, 4, -1)[:, :, 1:] = u[:, :, None]

    if proba_hints > 0:
        h = torch.rand(input.size(), device=input.device) * mask_generate
        # NOTE(review): the sorted *values* are uniform samples in [0, 1),
        # so comparing them to the integer args.nb_hints is always true
        # whenever nb_hints >= 1; a rank-based comparison may have been
        # intended here — confirm against the semantics of args.nb_hints.
        mask_hints = h.sort(dim=1, descending=True).values < args.nb_hints
        # per-sample coin flip: only ~proba_hints of the samples get hints
        v = torch.rand(nb, device=input.device)[:, None]
        mask_hints = mask_hints * (v < proba_hints).long()
        # hinted positions are removed from the generation mask
        mask_generate = (1 - mask_hints) * mask_generate

    # noise = quiz_machine.problem.pure_noise(nb, input.device)
    targets = input
    input = (1 - mask_generate) * targets # + mask_generate * noise

    return input, targets, mask_generate
+
+
def predict(model, quizzes, local_device=main_device):
    """Fill the masked segment of every quiz with samples from `model`.

    The quizzes are masked with batch_prediction (no hints), processed in
    physical batches, and the masked positions are replaced by samples
    from the model's categorical output; unmasked positions are kept.

    NOTE(review): assumes quizzes.size(0) is a multiple of
    args.physical_batch_size — the reshape fails otherwise.
    """
    model.eval().to(local_device)

    input, targets, mask = batch_prediction(quizzes.to(local_device))

    input_batches = input.reshape(-1, args.physical_batch_size, input.size(1))
    targets_batches = targets.reshape(-1, args.physical_batch_size, targets.size(1))
    mask_batches = mask.reshape(-1, args.physical_batch_size, mask.size(1))

    record = []

    # pure inference: disable autograd bookkeeping (the original tracked
    # gradients for every forward pass, wasting memory for no benefit)
    with torch.no_grad():
        for input, targets, mask in tqdm.tqdm(
            zip(input_batches, targets_batches, mask_batches),
            dynamic_ncols=True,
            desc="predict",
            total=quizzes.size(0) // args.physical_batch_size,
        ):
            # noise = quiz_machine.problem.pure_noise(input.size(0), input.device)
            input = (1 - mask) * input # + mask * noise
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(NTC_channel_cat(input, mask))
            dist = torch.distributions.categorical.Categorical(logits=logits)
            result = (1 - mask) * input + mask * dist.sample()
            record.append(result)

    return torch.cat(record)
+
+
+######################################################################
+
+
def batch_generation(input):
    """Build a denoising triple (input, targets, mask_generate) for training.

    A per-sample number of corruption steps t is drawn (small t far more
    likely), each token is independently replaced by pure noise with the
    probability of having been corrupted at least once after t steps, and
    the model must reconstruct everything except the first token of each
    of the four segments.
    """
    nb = input.size(0)
    # decaying distribution over iteration indices: weight drops by a
    # factor 10 from the first to the last diffusion iteration
    probs_iterations = 0.1 ** torch.linspace(
        0, 1, args.diffusion_nb_iterations, device=input.device
    )
    probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
    probs_iterations = probs_iterations.expand(nb, -1)
    dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
    # per-sample number of corruption steps, in [1, diffusion_nb_iterations]
    t = dist.sample() + 1
    r = torch.rand(input.size(), device=input.device)
    # probability that a token was corrupted at least once after t steps
    proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
    mask_erased = (r <= proba_erased[:, None]).long()

    noise = quiz_machine.problem.pure_noise(nb, input.device)

    targets = input
    input = (1 - mask_erased) * input + mask_erased * noise
    # generate everything except the first token of each of the 4 segments
    mask_generate = input.new_full(input.size(), 1)
    mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0

    return input, targets, mask_generate
+
+
def prioritized_rand(low):
    """Per-row random values in [0, 1) where positions with low == 0
    tend to receive the larger values.

    Rows of uniform samples are sorted descending, then scattered so
    that positions flagged in `low` end up last in the assignment order
    (jittered by a second uniform draw), hence get the smaller values.
    """
    sorted_vals = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
    jitter = torch.rand(low.size(), device=low.device) + low.long()
    destination = jitter.sort(dim=1).indices
    out = sorted_vals.new(sorted_vals.size())
    out.scatter_(dim=1, index=destination, src=sorted_vals)
    return out
+
+
def generate(model, nb, local_device=main_device):
    """Sample nb quizzes by iteratively denoising pure noise with `model`.

    Every token except the first of each of the four segments is
    generated.  At each iteration the model proposes a full sample and a
    random subset of positions — biased by prioritized_rand towards the
    disagreeing ones — is committed; a row stops being updated once it
    has converged.

    Fixes a copy/paste defect: the body referenced `self.*` although this
    is a module-level function (NameError at runtime).  The intended
    names — args.diffusion_nb_iterations, args.diffusion_proba_corruption
    and the module-level prioritized_rand — are used instead.
    """
    input = quiz_machine.problem.pure_noise(nb, local_device)
    mask_generate = input.new_full(input.size(), 1)
    # keep the first token of each of the 4 segments fixed
    mask_generate.reshape(mask_generate.size(0), 4, -1)[:, :, 0] = 0

    changed = True
    for _ in range(args.diffusion_nb_iterations):
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(NTC_channel_cat(input, mask_generate))
        dist = torch.distributions.categorical.Categorical(logits=logits)
        output = dist.sample()

        # positions where the proposal disagrees with the current state
        # get priority in the random update order
        r = prioritized_rand(input != output)
        mask_changes = (r <= args.diffusion_proba_corruption).long()
        update = (1 - mask_changes) * input + mask_changes * output

        if update.equal(input):
            break
        else:
            # `changed` tracks rows that evolved at every iteration so
            # far; converged rows are frozen and keep their value
            changed = changed & (update != input).max(dim=1).values
            input[changed] = update[changed]

    return input
+
+
+######################################################################
+
+
def batch_interleave(a, b, perm):
    """Concatenate a and b row-wise, shuffle rows with perm, and split
    the result into batches of args.physical_batch_size rows."""
    merged = torch.cat([a, b])
    shuffled = merged[perm]
    return shuffled.reshape(-1, args.physical_batch_size, a.size(1))
+
+
def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
    """Run one full pass over the quiz set, training or evaluating `model`.

    Half of the quizzes get the masked-prediction objective, half the
    diffusion-generation objective; the two halves are shuffled together
    and processed in physical batches with gradient accumulation up to
    args.batch_size.  Logs the mean loss at the end.

    NOTE(review): assumes the quiz-set size is a multiple of
    args.physical_batch_size (batch_interleave's reshape) — confirm.
    """
    if train:
        label = "train"
        model.train().to(local_device)
        optimizer_to(model.optimizer, local_device)
    else:
        label = "test"
        model.eval().to(local_device)

    nb_samples, acc_loss = 0, 0.0

    quizzes = quiz_machine.quiz_set(
        args.nb_train_samples if train else args.nb_test_samples,
        c_quizzes,
        args.c_quiz_multiplier,
    )

    # first half: masked prediction (with hints), second half: generation
    input_p, input_g = quizzes.to(local_device).chunk(2)
    input_p, targets_p, mask_p = batch_prediction(input_p, proba_hints=0.5)
    input_g, targets_g, mask_g = batch_generation(input_g)

    perm = torch.randperm(quizzes.size(0), device=local_device)
    input_batches = batch_interleave(input_p, input_g, perm)
    targets_batches = batch_interleave(targets_p, targets_g, perm)
    mask_batches = batch_interleave(mask_p, mask_g, perm)

    for input, targets, mask in tqdm.tqdm(
        zip(input_batches, targets_batches, mask_batches),
        dynamic_ncols=True,
        desc=label,
        total=quizzes.size(0) // args.physical_batch_size,
    ):
        if train and nb_samples % args.batch_size == 0:
            model.optimizer.zero_grad()

        # track gradients only when training; the original also tracked
        # them during evaluation, wasting memory for no benefit
        with torch.set_grad_enabled(train):
            with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = model(NTC_channel_cat(input, mask))

            loss = NTC_masked_cross_entropy(logits, targets, mask)

        acc_loss += loss.item() * input.size(0)
        nb_samples += input.size(0)

        if train:
            loss.backward()

            if nb_samples % args.batch_size == 0:
                model.optimizer.step()

    log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
+
+
def one_train_test_epoch(model, n_epoch, c_quizzes, local_device=main_device):
    """Train then evaluate `model`, save a prediction image, and record
    its test accuracy in model.test_accuracy.

    Fixes a defect: the original hard-coded local_device=main_device in
    both one_epoch calls and omitted it in predict, so the per-GPU
    dispatch (each worker passing its own device) was silently ignored
    and all work ran on main_device.
    """
    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=True)

    one_epoch(model, n_epoch, c_quizzes, local_device=local_device, train=False)

    quizzes = quiz_machine.quiz_set(150, c_quizzes, args.c_quiz_multiplier)
    result = predict(model, quizzes, local_device=local_device).to("cpu")

    quiz_machine.problem.save_quizzes_as_image(
        args.result_dir,
        f"culture_prediction_{n_epoch}_{model.id}.png",
        quizzes=result[:128],
    )

    # a quiz is correct only if every token matches
    nb_correct = (quizzes == result).min(dim=1).values.long().sum()
    model.test_accuracy = nb_correct / quizzes.size(0)
+
+
######################################################################
import attae
# None if c_quizzes is None else c_quizzes[agreements[:, model.id]],
multithread_execution(
- one_epoch,
- [
- (model, quiz_machine, n_epoch, c_quizzes, gpu)
- for model, gpu in zip(weakest_models, gpus)
- ],
+ one_train_test_epoch,
+ [(model, n_epoch, c_quizzes, gpu) for model, gpu in zip(weakest_models, gpus)],
)
# --------------------------------------------------------------------