+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
+######################################################################
+
+if args.nb_new_c_quizzes_for_train is None:
+ args.nb_new_c_quizzes_for_train = args.nb_train_samples // 100
+
+if args.nb_new_c_quizzes_for_test is None:
+ args.nb_new_c_quizzes_for_test = args.nb_test_samples // 100
+
+log_string(
+ f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
+)
+
+######################################################################
+
+if args.dirty_debug:
+ args.accuracy_to_make_c_quizzes = 0.0
+ args.nb_gpts = 2
+ args.nb_new_c_quizzes_for_train = 100
+ args.nb_new_c_quizzes_for_test = 10
+
+######################################################################
+
+if args.test == "tsne":
+ model = models[0]
+
+ quizzes = []
+ labels = []
+ nb_samples_per_task = 1000
+
+ for n, t in enumerate(args.grids_world_tasks.split(",")):
+ quizzes.append(
+ quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t])
+ )
+ labels.append(torch.full((quizzes[-1].size(0),), n))
+
+ quizzes = torch.cat(quizzes, dim=0)
+ labels = torch.cat(labels, dim=0)
+
+ with torch.autograd.no_grad():
+ model.eval().to(main_device)
+ record = []
+ for input, targets in zip(
+ quizzes.split(args.batch_size), labels.split(args.batch_size)
+ ):
+ input = input.to(main_device)
+ bs = mygpt.BracketedSequence(input)
+ bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+ bs = model.embedding(bs)
+ bs = model.trunk[args.nb_blocks // 2](bs)
+ record.append((bs.x.to("cpu"), targets))
+
+ x = torch.cat([x for x, y in record], dim=0).flatten(1)
+ y = torch.cat([y for x, y in record], dim=0)
+
+ print(f"{x.size()=} {y.size()=}")
+ # torch.save((x,y), "/tmp/embed.pth")
+ # exit(0)
+
+ from sklearn.manifold import TSNE
+
+ x_np = x.numpy()
+ z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np)
+ z = torch.from_numpy(z_np)
+
+ print(f"{z.size()=}")
+
+ with open("/tmp/result.dat", "w") as f:
+ for k in range(z.size(0)):
+ f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n")
+
+ exit(0)
+
+######################################################################
+
+if args.test == "generator":
+ token_prolog_0 = vocabulary_size + 0
+ token_prolog_1 = vocabulary_size + 1
+ token_prolog_2 = vocabulary_size + 2
+ generator_vocabulary_size = vocabulary_size + 3
+
+ generator = mygpt.MyGPT(
+ vocabulary_size=generator_vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ compute_attzero=compute_causal_attzero,
+ dropout=args.dropout,
+ ).to(main_device)
+
+ generator.main_test_accuracy = 0.0
+
+ filename = f"generator.pth"
+
+ try:
+ d = torch.load(os.path.join(args.result_dir, filename))
+ generator.load_state_dict(d[0])
+ generator.main_test_accuracy = d[1]
+ log_string(f"successfully loaded {filename}")
+ except FileNotFoundError:
+ log_string(f"cannot find {filename}")
+ pass
+
+ for n_epoch in range(args.nb_epochs):
+ one_generator_epoch(
+ generator,
+ quiz_machine=quiz_machine,
+ models=models,
+ fraction_w_quizzes=1 if n_epoch < 25 else 0.5,
+ local_device=main_device,
+ )
+
+ filename = f"generator.pth"
+ torch.save(
+ (generator.state_dict(), generator.main_test_accuracy),
+ os.path.join(args.result_dir, filename),
+ )
+ log_string(f"wrote {filename}")
+
+ c_quizzes, prologs = generate_c_quizzes_with_generator(
+ generator, quiz_machine, args.batch_size
+ )
+
+ seq_logproba = quiz_machine.models_logprobas(
+ models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ ) + quiz_machine.models_logprobas(
+ models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ )
+
+ probas = seq_logproba.exp()
+
+ u0 = probas <= args.proba_not_understands
+ u2 = probas >= args.proba_understands
+ u1 = (u0 | u2) == False
+
+ predicted_prologs = (
+ (u0.long() * token_prolog_0)
+ + (u1.long() * token_prolog_1)
+ + (u2.long() * token_prolog_2)
+ )
+
+ comments = []
+
+ nb_errors = (predicted_prologs != prologs).long().sum()
+ nb_total = prologs.numel()
+
+ log_string(f"generator_error {nb_errors} / {nb_total}")
+
+ def readable(prologs):
+ return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2)
+
+ for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)):
+ sa = "prolog " + " ".join(
+ [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)]