# Diff fragment (indentation stripped): in this change the per-iteration
# "nb_changed" log call is replaced by a commented-out copy, so the message
# is no longer emitted.
#
# all_changed is created filled with True, one entry per row of all_input,
# on the same device as all_input.
all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
# At most args.diffusion_nb_iterations iterations are run.
for it in range(args.diffusion_nb_iterations):
- log_string(f"nb_changed {all_changed.long().sum().item()}")
+ # log_string(f"nb_changed {all_changed.long().sum().item()}")
# Leave the loop early as soon as no entry of all_changed is True.
# NOTE(review): the code that updates all_changed is not visible in this
# fragment -- presumably later in the loop body; confirm.
if not all_changed.any():
break
# Diff fragment: c_quizzes starts as None; the two timing counters below
# (and the blank line after them) are removed by this change.
# NOTE(review): confirm nothing elsewhere still reads time_c_quizzes or
# time_train once these initialisations are gone.
c_quizzes = None
-time_c_quizzes = 0
-time_train = 0
-
######################################################################
# Diff fragment: the freshly generated quizzes are now bound to
# new_c_quizzes (previously c_quizzes), the "high_delta" image is dropped,
# and the new quizzes are appended to c_quizzes, which is then trimmed to
# its last args.nb_train_samples rows.
#
# Per-GPU quota: ceiling division of args.nb_c_quizzes by the number of GPUs.
nb_gpus = len(gpus)
nb_c_quizzes_to_generate = (args.nb_c_quizzes + nb_gpus - 1) // nb_gpus
# generate_c_quizzes is dispatched with one (models, quota, gpu) argument
# tuple per GPU; the result is unpacked as a one-element tuple.
# NOTE(review): presumably multithread_execution merges the per-GPU outputs,
# in which case the total can exceed args.nb_c_quizzes because of the
# rounding above -- confirm against its definition.
- (c_quizzes,) = multithread_execution(
+ (new_c_quizzes,) = multithread_execution(
generate_c_quizzes,
[(models, nb_c_quizzes_to_generate, gpu) for gpu in gpus],
)
# Two images are written from the first 256 new quizzes: one with
# solvable_only=False and one with solvable_only=True.
save_quiz_image(
models,
- c_quizzes[:256],
+ new_c_quizzes[:256],
f"culture_c_quiz_{n_epoch:04d}.png",
solvable_only=False,
)
save_quiz_image(
models,
- c_quizzes[:256],
+ new_c_quizzes[:256],
f"culture_c_quiz_{n_epoch:04d}_solvable.png",
solvable_only=True,
)
# Removed by this change: the quizzes were reshaped to (N, 4, -1), the first
# element of the last axis was dropped, and rows were ordered by decreasing
# count of positions where u[:, 2] and u[:, 3] differ; that ordering fed a
# third "solvable_high_delta" image, which is no longer produced.
- u = c_quizzes.reshape(c_quizzes.size(0), 4, -1)[:, :, 1:]
- i = (u[:, 2] != u[:, 3]).long().sum(dim=1).sort(descending=True).indices
# The size log now reports the newly generated batch only and is emitted
# before the accumulation below.
+ log_string(f"generated_c_quizzes {new_c_quizzes.size()=}")
- save_quiz_image(
- models,
- c_quizzes[i][:256],
- f"culture_c_quiz_{n_epoch:04d}_solvable_high_delta.png",
- solvable_only=True,
+ c_quizzes = (
+ new_c_quizzes
+ if c_quizzes is None
+ else torch.cat([c_quizzes, new_c_quizzes])
)
-
- log_string(f"generated_c_quizzes {c_quizzes.size()=}")
# The added lines above make c_quizzes the new batch when nothing has been
# accumulated yet (c_quizzes is None), otherwise the concatenation of the
# existing and new quizzes. The line below keeps only the last
# args.nb_train_samples rows.
# NOTE(review): if args.nb_train_samples is 0 the slice is [-0:], which
# keeps every row rather than none -- confirm 0 is never a valid setting.
+ c_quizzes = c_quizzes[-args.nb_train_samples :]
# Reset test_accuracy to 0 on every model (the loop body may continue
# beyond the visible part of this fragment).
for model in models:
model.test_accuracy = 0