help="A comma-separated subset of: " + grids_tasks + ".",
)
-parser.add_argument(
- "--grids_science_tasks",
- type=str,
- default=None,
- help="A comma-separated subset of: " + grids_tasks + ", or None.",
-)
-
######################################################################
parser.add_argument("--sky_height", type=int, default=6)
if args.result_dir is None:
args.result_dir = f"results_culture"
-assert not args.grids_science_tasks or (
- len(
- set(args.grids_world_tasks.split(","))
- & set(args.grids_science_tasks.split(","))
- )
- == 0
-), "World and science tasks have to be disjoint"
-
######################################################################
default_args = {
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-if args.problem == "sky":
- problem = sky.Sky(
- height=args.sky_height,
- width=args.sky_width,
- nb_birds=args.sky_nb_birds,
- nb_iterations=args.sky_nb_iterations,
- speed=args.sky_speed,
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- )
-
-elif args.problem == "grids":
- problem = grids.Grids(
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- tasks=args.grids_world_tasks,
- )
-
- if args.grids_science_tasks is None:
- science_w_quizzes = None
- else:
- science_problem = grids.Grids(
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- tasks=args.grids_science_tasks,
- )
- science_w_quizzes = science_problem.generate_w_quizzes(100)
-
- if not args.resume:
- science_problem.save_some_examples(args.result_dir, "science_")
-
-
-else:
- raise ValueError
+problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_world_tasks,
+)
if not args.resume:
problem.save_some_examples(args.result_dir)
######################################################################
-def save_additional_results(model, models, science_w_quizzes):
+def save_additional_results(model, models):
# Save generated quizzes with the successive steps
recorder = []
log_string(f"wrote {filename}")
- ######################################################################
-
- if science_w_quizzes is not None:
- struct = ("A", "f_A", "B", "f_B")
- mask = (0, 0, 0, 1)
- result, correct, _ = quiz_machine.predict(
- model=model,
- quizzes=science_w_quizzes.to(main_device),
- struct=struct,
- mask=mask,
- )
-
- predicted_parts = torch.tensor(mask, device=correct.device)[None, :].expand(
- correct.size(0), -1
- )
- correct = (2 * correct - 1) * (predicted_parts.sum(dim=-1) == 1).long()
-
- nb_correct = (correct == 1).long().sum()
- nb_total = (correct != 0).long().sum()
-
- log_string(
- f"science_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
- )
-
- i = correct == 1
- j = correct != 1
-
- result = torch.cat([result[i], result[j]], dim=0)
- correct = torch.cat([correct[i], correct[j]], dim=0)
- correct_parts = predicted_parts * correct[:, None]
-
- result = result[:128]
- predicted_parts = predicted_parts[:128]
- correct_parts = correct_parts[:128]
-
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- f"culture_science_{n_epoch:04d}_{model.id:02d}.png",
- quizzes=result,
- predicted_parts=predicted_parts,
- correct_parts=correct_parts,
- )
-
######################################################################
for model in models:
model.recorded_c_quizzes = []
+ teaching_count = torch.zeros(len(models), len(models), dtype=torch.int64)
+
while nb_validated < nb_to_validate:
model_for_generation = models[torch.randint(len(models), (1,)).item()]
proba_other_solutions[dont_get_it] = -1
i = proba_other_solutions.argmax()
model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
+ teaching_count[i, model.id] += 1
nb_validated += 1
duration = time.perf_counter() - start_time
f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
)
+ for s in range(teaching_count.size(0)):
+ o = [x.item() for x in teaching_count[s]]
+ log_string(f"teacher model {s} to {o}")
+
for model in models:
new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
if n > 0:
model.train_c_quiz_bags.append(new_bag[:n])
if n < new_bag.size(0):
- model.test_c_quiz_bags.append(new_bag[:n])
+ model.test_c_quiz_bags.append(new_bag[n:])
- vq = new_bag[:128]
+ c_quizzes = new_bag[:128]
- seq_logprobas = quiz_machine.models_logprobas(
- models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
- models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
+ l = [
+ quiz_machine.models_logprobas(
+ model,
+ c_quizzes,
+ ("A", "f_A", "B", "f_B"),
+ (0, 0, 0, 1),
+ (0, 0, 1, 0),
+ )
+ + quiz_machine.models_logprobas(
+ model,
+ c_quizzes,
+ ("f_A", "A", "f_B", "B"),
+ (0, 0, 0, 1),
+ (0, 0, 1, 0),
+ )
+ for model in models
+ ]
+
+ seq_logprobas = torch.cat([x[:, None] for x in l], dim=1)
probas = seq_logprobas.exp()
filename = f"culture_c_quiz_{n_epoch:04d}.png"
quiz_machine.problem.save_quizzes_as_image(
- args.result_dir, filename, vq, comments=comments
+ args.result_dir, filename, c_quizzes, comments=comments
)
log_string(
log_string(f"wrote {filename}")
for model in weakest_models:
- save_additional_results(model, models, science_w_quizzes)
+ save_additional_results(model, models)
######################################################################