-elif args.task == "snake":
- task = tasks.Snake(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
- height=args.snake_height,
- width=args.snake_width,
- nb_colors=args.snake_nb_colors,
- length=args.snake_length,
- prompt_length=args.snake_length // 2,
- device=device,
+ q_p, q_g = quizzes.to(local_device).chunk(2)
+
+ # Half of the samples train the prediction: noise is injected into all
+ # of them, and hints into a random subset (each picked with prob. 0.5)
+ b_p = samples_for_prediction_imt(q_p)
+ b_p = add_noise_imt(b_p)
+ half = torch.rand(b_p.size(0)) < 0.5
+ b_p[half] = add_hints_imt(b_p[half])
+
+ # The other half are denoising examples for the generation
+ b_g = samples_for_generation_imt(q_g)
+
+ imt_set = torch.cat([b_p, b_g])
+ imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
+
+ if train:
+ label = "train"
+ model.train().to(local_device)
+ optimizer_to(model.optimizer, local_device)
+ batch_size = args.train_batch_size
+ else:
+ label = "test"
+ model.eval().to(local_device)
+ batch_size = args.eval_batch_size
+
+ nb_samples, acc_loss = 0, 0.0
+
+ for imt in tqdm.tqdm(
+ imt_set.split(batch_size),
+ dynamic_ncols=True,
+ desc=label,
+ total=quizzes.size(0) // batch_size,
+ delay=10,
+ ):
+ input, masks, targets = imt.unbind(dim=1)
+ if train and nb_samples % args.batch_size == 0:
+ model.optimizer.zero_grad()
+
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+ logits = model(input * 2 + masks)
+
+ loss_per_token = F.cross_entropy(
+ logits.transpose(1, 2), targets, reduction="none"
+ )
+ loss = (loss_per_token * masks).mean()
+ acc_loss += loss.item() * imt.size(0)
+ nb_samples += imt.size(0)
+
+ if train:
+ loss.backward()
+
+ if nb_samples % args.batch_size == 0:
+ model.optimizer.step()
+
+ log_string(f"{label}_loss {n_epoch} model {model.id} {acc_loss/nb_samples}")
+
+
+######################################################################
+
+
+def save_inference_images(model, n_epoch, c_quizzes, c_quiz_multiplier, local_device):
+ # Save some images of the prediction results
+
+ quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
+ imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
+ masks = imt_set[:, 1].to("cpu")
+
+ correct = (quizzes == result).min(dim=1).values.long()
+ correct_parts = (2 * correct - 1)[:, None] * masks.reshape(masks.size(0), 4, -1)[
+ :, :, 1
+ ]
+ predicted_parts = correct_parts.abs()
+
+ problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_prediction_{n_epoch}_{model.id}.png",
+ quizzes=result[:128],
+ predicted_parts=predicted_parts[:128],
+ correct_parts=correct_parts[:128],