nb_samples_accumulated = 0
full_input, full_mask_loss = quiz_machine.data_input(
- model, args.nb_test_samples
+ args.nb_test_samples, model.test_c_quiz_bags
)
src = zip(
full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
hard_w_quizzes = []
- full_input, full_mask_loss = quiz_machine.data_input(model, args.nb_train_samples)
+ full_input, full_mask_loss = quiz_machine.data_input(
+ args.nb_train_samples, model.train_c_quiz_bags
+ )
src = zip(full_input.split(args.batch_size), full_mask_loss.split(args.batch_size))
for input, mask_loss in tqdm.tqdm(
# This is nb_quizzes x nb_models
- seq_logproba = quiz_machine.models_logprobas(
- models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
- models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
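+ # models_logprobas now takes a single model, so collect one
+ # log-probas tensor of shape (nb_quizzes,) per model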
+ l = [
+ quiz_machine.models_logprobas(
+ model, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ )
+ + quiz_machine.models_logprobas(
+ model, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ )
+ for model in models
+ ]
+
+ seq_logprobas = torch.stack(l, dim=1)
- probas = seq_logproba.exp()
+ probas = seq_logprobas.exp()
comments = []
- for l in seq_logproba:
+ for l in seq_logprobas:
comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
##
######################################################################
-def record_new_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
- nb_to_validate = nb_for_train + nb_for_test
- nb_to_generate_per_iteration = max(args.physical_batch_size, nb_to_validate)
- nb_validated = 0
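+
+# Product, over the two presentation orders, of the probas the model
+# assigns to the solutions of the quizzes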
+def model_proba_solutions(model, quizzes):
+ l = quiz_machine.models_logprobas(
+ model, quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ ) + quiz_machine.models_logprobas(
+ model, quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
+ )
+
+ return l.exp()
+
- recorded_validated = []
+def record_new_c_quizzes(models, quiz_machine, nb_for_train, nb_for_test):
+ nb_validated, nb_to_validate = 0, nb_for_train + nb_for_test
+ nb_to_generate_per_iteration = nb_to_validate
start_time = time.perf_counter()
- nb_validated_per_model = torch.zeros(len(models), dtype=torch.int64)
+ for model in models:
+ model.recorded_c_quizzes = []
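+
+ # Generate c_quizzes with models picked at random until enough of
+ # them have been recorded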
- while nb_validated_per_model.sum() < nb_to_validate:
+ while nb_validated < nb_to_validate:
model_for_generation = models[torch.randint(len(models), (1,)).item()]
# We generate quizzes with a procedure that injects some
c_quizzes = c_quizzes[to_keep]
- # This is nb_quizzes x nb_models
+ # Compute every model's response to the c_quizzes, and its own
+ # estimate of the proba that this response is correct
solved_c_quizzes = c_quizzes[:, None, :].expand(-1, len(models), -1).clone()
- seq_logproba = torch.zeros(
+ proba_own_solution = torch.zeros(
c_quizzes.size(0), len(models), device=solved_c_quizzes.device
)
- for m in models:
- (
- solved_c_quizzes[:, m.id],
- _,
- seq_logproba[:, m.id],
- ) = quiz_machine.predict(
- m,
- solved_c_quizzes[:, m.id],
+ for model in models:
+ (solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
+ model,
+ solved_c_quizzes[:, model.id],
struct=("A", "f_A", "B", "f_B"),
mask=(0, 0, 0, 1),
)
- #!!!!!!!!!!!!!!!!!!!!
- for m in range(seq_logproba.size(1)):
- l = quiz_machine.models_logprobas(
- [models[m]],
- solved_c_quizzes[:, m, :],
- ("A", "f_A", "B", "f_B"),
- (0, 0, 0, 1),
- (0, 0, 0, 0),
- )
- for s in range(seq_logproba.size(0)):
- print("DEBUG", seq_logproba[s, m].item(), l[s, 0].item())
- exit(0)
- #!!!!!!!!!!!!!!!!!!!!!!!!!
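+ # The model's own estimate of the proba that its response is correct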
+ u = model_proba_solutions(model, solved_c_quizzes[:, model.id])
- # FINISH
+ proba_own_solution[:, model.id] = u
- seq_logproba = quiz_machine.models_logprobas(
- models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
- models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
-
- probas = seq_logproba.exp()
-
- nb_succeed = (probas >= args.proba_understands).long().sum(dim=1)
- nb_fail = (probas <= args.proba_not_understands).long().sum(dim=1)
-
- to_keep = (
- # (nb_succeed + nb_fail == probas.size(1))
- (nb_succeed >= args.min_succeed_to_validate)
- & (nb_fail >= 1)
- & (nb_fail <= args.max_fail_to_validate)
- )
-
- c_quizzes = c_quizzes[to_keep]
-
- if c_quizzes.size(0) > 0:
- nb_validated_per_model[model_for_generation.id] += c_quizzes.size(0)
- recorded_validated.append(c_quizzes)
- nb_validated = c_quizzes.size(0)
- else:
- nb_validated = 0
+ # Now, for every model not confident in its own response, pick,
+ # among the responses of the confident models, the one to which it
+ # assigns the highest proba
- total_nb_validated = nb_validated_per_model.sum().item()
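+ # Note that nb_validated counts (model, quiz) pairs: the same c_quiz
+ # recorded for several models is counted once per model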
+ for s in range(proba_own_solution.size(0)):
+ dont_get_it = proba_own_solution[s, :] < args.proba_understands
+ if not dont_get_it.all():
+ for model in models:
+ if dont_get_it[model.id]:
+ proba_other_solutions = model_proba_solutions(
+ model, solved_c_quizzes[s]
+ )
+ proba_other_solutions[dont_get_it] = -1
+ i = proba_other_solutions.argmax()
+ model.recorded_c_quizzes.append(solved_c_quizzes[s, i])
+ nb_validated += 1
duration = time.perf_counter() - start_time
- if total_nb_validated > 0:
- if total_nb_validated < nb_to_validate:
- d = (
- (nb_to_validate - total_nb_validated)
- * duration
- / total_nb_validated
- )
+ if nb_validated > 0:
+ if nb_validated < nb_to_validate:
+ d = (nb_to_validate - nb_validated) * duration / nb_validated
e = (datetime.datetime.now() + datetime.timedelta(seconds=d)).strftime(
"%a %H:%M"
)
e = "???"
log_string(
- f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {total_nb_validated} / {nb_to_validate} (finishes {e} -- {int((total_nb_validated * 3600)/duration)}/h)"
- )
-
- validated_quizzes = torch.cat(recorded_validated, dim=0)
-
- ######################################################################
- # store the new c_quizzes which have been validated
-
- v_train = validated_quizzes[:nb_for_train]
- quiz_machine.store_c_quizzes(v_train, for_train=True)
-
- v_test = validated_quizzes[nb_for_train:nb_to_validate]
- quiz_machine.store_c_quizzes(v_test, for_train=False)
-
- ######################################################################
- # save images
-
- vq = validated_quizzes[torch.randperm(validated_quizzes.size(0))[:128]]
-
- if vq.size(0) > 0:
- seq_logproba = quiz_machine.models_logprobas(
- models, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
- models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
-
- probas = seq_logproba.exp()
-
- comments = []
-
- for l in seq_logproba:
- comments.append("proba " + " ".join([f"{x.exp().item():.02f}" for x in l]))
-
- filename = f"culture_c_quiz_{n_epoch:04d}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir, filename, vq, comments=comments
- )
-
-
-######################################################################
-
-# The generator is very similar to a "solving GPT" except that it
-# deals with quizzes prologued with one token per solving GPT that
-# indicates if the said model solves it or not.
-#
-# There are three levels of solving 0->proba<=proba_not_understands,
-# 2->proba>=proba_understands and 1 otherwise.
-
-
-def generate_c_quizzes_with_generator(generator, quiz_machine, nb):
- generator.to(main_device)
-
- struct = ("A", "f_A", "B", "f_B")
-
- c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
- ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1))
-
- i = F.one_hot(
- torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
- num_classes=args.nb_gpts,
- )
-
- prologs_c_quizzes = token_prolog_0 * i + token_prolog_2 * (1 - i)
- prologs_ar_mask = ar_mask.new_zeros(ar_mask.size(0), prologs_c_quizzes.size(1))
-
- prologued_c_quizzes = torch.cat([prologs_c_quizzes, c_quizzes], dim=1).to(
- main_device
- )
- prologued_ar_mask = torch.cat([prologs_ar_mask, ar_mask], dim=1).to(main_device)
-
- seq_logproba = torch.zeros(
- prologued_c_quizzes.size(0), device=prologued_c_quizzes.device
- )
-
- generator.temperature = args.temperature_hot
-
- with torch.autograd.no_grad():
- t = generator.training
- generator.eval()
-
- one_batch_masked_inplace_autoregression(
- generator,
- prologued_c_quizzes,
- prologued_ar_mask,
- seq_logproba,
- deterministic_synthesis=False,
+ f"keep c_quizzes model {model_for_generation.id} validated {nb_validated} / {nb_to_generate_per_iteration} ({100*nb_validated/nb_to_generate_per_iteration:.02f}%) nb_accumulated {nb_validated} / {nb_to_validate} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
)
- generator.train(t)
-
- generator.reset_transformations()
-
- prologued_c_quizzes = (
- prologued_c_quizzes * (prologued_c_quizzes < vocabulary_size).long()
- )
-
- c_quizzes = prologued_c_quizzes[:, prologs_c_quizzes.size(1) :]
-
- return c_quizzes.to("cpu"), prologs_c_quizzes.to("cpu")
-
-
-def batches_for_generator(generator, quiz_machine, models, fraction_w_quizzes=1.0):
- samples = []
-
- for _ in range(args.nb_train_samples // args.batch_size):
- while sum([x.size(0) for x in samples]) < args.batch_size:
- # Generate a bunch of quizzes
-
- if torch.rand(1).item() <= fraction_w_quizzes:
- # Either we start with the world quizzes
- c_quizzes = quiz_machine.problem.generate_w_quizzes(
- args.batch_size, progress_bar=False
- )
- else:
- # Or we use the generator itself to generate them
- c_quizzes, _ = generate_c_quizzes_with_generator(
- generator, quiz_machine, args.batch_size
- )
-
- # We remove the trivial ones
- to_keep = quiz_machine.problem.trivial(c_quizzes) == False
- c_quizzes = c_quizzes[to_keep]
-
- # If there are remaining ones, we compute the true prolog
- # that indicates how the GPTs solve it
-
- if c_quizzes.size(0) > 0:
- seq_logproba = quiz_machine.models_logprobas(
- models,
- c_quizzes,
- ("A", "f_A", "B", "f_B"),
- (0, 0, 0, 1),
- (0, 0, 1, 0),
- ) + quiz_machine.models_logprobas(
- models,
- c_quizzes,
- ("f_A", "A", "f_B", "B"),
- (0, 0, 0, 1),
- (0, 0, 1, 0),
- )
-
- probas = seq_logproba.exp()
-
- u0 = probas <= args.proba_not_understands
- u2 = probas >= args.proba_understands
- u1 = (u0 | u2) == False
-
- prologs = (
- (u0.long() * token_prolog_0)
- + (u1.long() * token_prolog_1)
- + (u2.long() * token_prolog_2)
- )
-
- prologued_c_quizzes = torch.cat([prologs, c_quizzes], dim=1)
-
- # nb_u2 = u2.long().sum(dim=1)
- # nb_u0 = u0.long().sum(dim=1)
- # prologued_c_quizzes = prologued_c_quizzes[(nb_u2 >= 1) & (nb_u0 >= 1)]
-
- if prologued_c_quizzes.size(0) > 0:
- samples.append(prologued_c_quizzes)
-
- # Now we yield a batch
-
- x = torch.cat(samples, dim=0)
- samples = [x[args.batch_size :]]
-
- yield x[: args.batch_size]
-
-
-def one_generator_epoch(
- generator, quiz_machine, models, fraction_w_quizzes, local_device=main_device
-):
- model.to(local_device).train()
-
- optimizer = torch.optim.Adam(generator.parameters(), lr=args.learning_rate)
-
- nb_train_samples, acc_train_loss = 0, 0.0
-
- src = batches_for_generator(
- generator=generator,
- quiz_machine=quiz_machine,
- models=models,
- fraction_w_quizzes=fraction_w_quizzes,
- )
-
- for input in tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc="training",
- total=args.nb_train_samples // args.batch_size,
- ):
- input = input.to(local_device)
-
- if nb_train_samples % args.batch_size == 0:
- optimizer.zero_grad()
-
- targets = input
-
- output = generator(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), targets)
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
-
- loss.backward()
-
- if nb_train_samples % args.batch_size == 0:
- optimizer.step()
-
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
- log_string(f"train_perplexity {n_epoch} generator - {train_perplexity}")
-
- generator.to(main_device)
-
-
-######################################################################
-
-
-def train_complexifier(model_gen, model_pred1, model_pred2):
- samples = []
- perf = []
+ for model in models:
+ # A model may have recorded no quizzes, and torch.cat fails on an
+ # empty list
+ if len(model.recorded_c_quizzes) == 0:
+ continue
+
+ new_bag = torch.cat([q[None, :] for q in model.recorded_c_quizzes], dim=0)
- optimizer = torch.optim.Adam(model_gen.parameters(), lr=args.learning_rate)
+ if new_bag.size(0) > 0:
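+ # Split the new bag between train and test in the requested proportion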
+ n = (new_bag.size(0) * nb_for_train) // (nb_for_train + nb_for_test)
+ if n > 0:
+ model.train_c_quiz_bags.append(new_bag[:n])
+ if n < new_bag.size(0):
+ model.test_c_quiz_bags.append(new_bag[n:])
- nb_train_samples, acc_train_loss = 0, 0.0
+ vq = new_bag[:128]
- for n_epoch in range(args.nb_epochs):
- for b in range(args.nb_train_samples // args.batch_size):
- while sum([x.size(0) for x in samples]) < args.batch_size:
- c_quizzes = quiz_machine.generate_c_quizzes(
- args.inference_batch_size,
- model_for_generation=model_gen,
- procedure=c_quizzes_procedure,
- )
- to_keep = quiz_machine.problem.trivial(c_quizzes) == False
- c_quizzes = c_quizzes[to_keep]
- if c_quizzes.size(0) > 0:
- seq_logproba = quiz_machine.models_logprobas(
- [model_pred1, model_pred2],
- c_quizzes,
- ("A", "f_A", "B", "f_B"),
- (0, 0, 0, 1),
- ) + quiz_machine.models_logprobas(
- [model_pred1, model_pred2],
- c_quizzes,
- ("f_A", "A", "f_B", "B"),
- (0, 0, 0, 1),
- )
- probas = seq_logproba.exp()
- to_keep = (probas[:, model_pred1.id] >= args.proba_understands) & (
- probas[:, model_pred2.id] <= args.proba_not_understands
- )
- log_string(
- f"generating {to_keep.long().sum()} / {c_quizzes.size(0)}"
- )
- c_quizzes = c_quizzes[to_keep]
- if c_quizzes.size(0):
- samples.append(c_quizzes)
-
- log_string(f"full batch {sum([x.size(0) for x in samples])}")
-
- x = torch.cat(samples, dim=0)
-
- input = x[: args.batch_size]
- samples = [x[args.batch_size :]]
-
- # -------------------
-
- seq_logproba = quiz_machine.models_logprobas(
- [model_pred1, model_pred2],
- input,
- ("A", "f_A", "B", "f_B"),
- (0, 0, 0, 1),
+ l = [
+ quiz_machine.models_logprobas(
+ m, vq, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
) + quiz_machine.models_logprobas(
- [model_pred1, model_pred2],
- input,
- ("f_A", "A", "f_B", "B"),
- (0, 0, 0, 1),
+ models, vq, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
)
+ for m in models
+ ]
+
+ # nb_quizzes x nb_models
+ seq_logprobas = torch.stack(l, dim=1)
+ probas = seq_logprobas.exp()
+
comments = []
- for l in seq_logproba:
+ for l in seq_logprobas:
comments.append(
- f"proba {l[model_pred1.id].exp().item():.02f} {l[model_pred2.id].exp().item():.02f}"
+ "proba " + " ".join([f"{x.exp().item():.02f}" for x in l])
)
- filename = f"batch_{n_epoch:04d}_{b:04d}.png"
+ filename = f"culture_c_quiz_{n_epoch:04d}.png"
quiz_machine.problem.save_quizzes_as_image(
- args.result_dir, filename, input, comments=comments
+ args.result_dir, filename, vq, comments=comments
)
- log_string(f"wrote {filename}")
-
- # ------------------------
-
- input = input.to(main_device)
-
- if nb_train_samples % args.batch_size == 0:
- optimizer.zero_grad()
-
- output = model_gen(mygpt.BracketedSequence(input)).x
- loss = F.cross_entropy(output.transpose(1, 2), input)
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
-
- loss.backward()
- if nb_train_samples % args.batch_size == 0:
- optimizer.step()
-
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
-
- log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
+ log_string(
+ f"nb_c_quizzes model {model.id} train {sum([q.size(0) for q in model.train_c_quiz_bags ])} test {sum([q.size(0) for q in model.test_c_quiz_bags ])}"
+ )
######################################################################
).to(main_device)
model.id = k
- model.c_quiz_bags = []
+ model.train_c_quiz_bags = []
+ model.test_c_quiz_bags = []
if args.schedule_free:
model.optimizer = schedulefree.AdamWScheduleFree(
######################################################################
-if args.test == "quant":
- nb_bits = 8
- for model in models:
- model.trunk.insert(
- 12,
- mygpt.CacheWrapper(
- mygpt.RandomBypass(
- nn.Sequential(
- nn.Linear(args.dim_model, nb_bits),
- mygpt.BSQ(nb_bits),
- nn.Linear(nb_bits, args.dim_model),
- ),
- 0.1,
- )
- ),
- )
-
- print(model)
- exit(0)
-
-
-######################################################################
-
current_epoch = 0
if args.resume:
######################################################################
-if args.test == "tsne":
- model = models[0]
-
- quizzes = []
- labels = []
- nb_samples_per_task = 1000
-
- for n, t in enumerate(args.grids_world_tasks.split(",")):
- quizzes.append(
- quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t])
- )
- labels.append(torch.full((quizzes[-1].size(0),), n))
-
- quizzes = torch.cat(quizzes, dim=0)
- labels = torch.cat(labels, dim=0)
-
- with torch.autograd.no_grad():
- model.eval().to(main_device)
- record = []
- for input, targets in zip(
- quizzes.split(args.batch_size), labels.split(args.batch_size)
- ):
- input = input.to(main_device)
- bs = mygpt.BracketedSequence(input)
- bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
- bs = model.embedding(bs)
- bs = model.trunk[args.nb_blocks // 2](bs)
- record.append((bs.x.to("cpu"), targets))
-
- x = torch.cat([x for x, y in record], dim=0).flatten(1)
- y = torch.cat([y for x, y in record], dim=0)
-
- print(f"{x.size()=} {y.size()=}")
- # torch.save((x,y), "/tmp/embed.pth")
- # exit(0)
-
- from sklearn.manifold import TSNE
-
- x_np = x.numpy()
- z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np)
- z = torch.from_numpy(z_np)
-
- print(f"{z.size()=}")
-
- with open("/tmp/result.dat", "w") as f:
- for k in range(z.size(0)):
- f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n")
-
- exit(0)
-
-######################################################################
-
-if args.test == "generator":
- token_prolog_0 = vocabulary_size + 0
- token_prolog_1 = vocabulary_size + 1
- token_prolog_2 = vocabulary_size + 2
- generator_vocabulary_size = vocabulary_size + 3
-
- generator = mygpt.MyGPT(
- vocabulary_size=generator_vocabulary_size,
- dim_model=args.dim_model,
- dim_keys=args.dim_keys,
- dim_hidden=args.dim_hidden,
- nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks,
- compute_attzero=compute_causal_attzero,
- dropout=args.dropout,
- ).to(main_device)
-
- generator.main_test_accuracy = 0.0
-
- filename = f"generator.pth"
-
- try:
- d = torch.load(os.path.join(args.result_dir, filename))
- generator.load_state_dict(d[0])
- generator.main_test_accuracy = d[1]
- log_string(f"successfully loaded {filename}")
- except FileNotFoundError:
- log_string(f"cannot find {filename}")
- pass
-
- for n_epoch in range(args.nb_epochs):
- one_generator_epoch(
- generator,
- quiz_machine=quiz_machine,
- models=models,
- fraction_w_quizzes=1 if n_epoch < 25 else 0.5,
- local_device=main_device,
- )
-
- filename = f"generator.pth"
- torch.save(
- (generator.state_dict(), generator.main_test_accuracy),
- os.path.join(args.result_dir, filename),
- )
- log_string(f"wrote {filename}")
-
- c_quizzes, prologs = generate_c_quizzes_with_generator(
- generator, quiz_machine, args.batch_size
- )
-
- seq_logproba = quiz_machine.models_logprobas(
- models, c_quizzes, ("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0)
- ) + quiz_machine.models_logprobas(
- models, c_quizzes, ("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0)
- )
-
- probas = seq_logproba.exp()
-
- u0 = probas <= args.proba_not_understands
- u2 = probas >= args.proba_understands
- u1 = (u0 | u2) == False
-
- predicted_prologs = (
- (u0.long() * token_prolog_0)
- + (u1.long() * token_prolog_1)
- + (u2.long() * token_prolog_2)
- )
-
- comments = []
-
- nb_errors = (predicted_prologs != prologs).long().sum()
- nb_total = prologs.numel()
-
- log_string(f"generator_error {nb_errors} / {nb_total}")
-
- def readable(prologs):
- return (prologs == token_prolog_1) + 2 * (prologs == token_prolog_2)
-
- for aa, ee, ff in zip(probas, readable(predicted_prologs), readable(prologs)):
- sa = "prolog " + " ".join(
- [f"{e.item()}/{f.item()}" for e, f in zip(ee, ff)]
- )
- sp = "proba " + " ".join([f"{p.item():.02f}" for p in aa])
- comments.append(sa + "\n" + sp)
-
- filename = f"generator_batch_{n_epoch:04d}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir, filename, c_quizzes, comments=comments
- )
- log_string(f"wrote {filename}")
-
- exit(0)
-
-######################################################################
-
for n_epoch in range(current_epoch, args.nb_epochs):
state = {"current_epoch": n_epoch}
filename = "state.pth"
record_new_c_quizzes(
models,
quiz_machine,
- nb_for_train=args.nb_new_c_quizzes_for_train,
- nb_for_test=args.nb_new_c_quizzes_for_test,
+ args.nb_new_c_quizzes_for_train,
+ args.nb_new_c_quizzes_for_test,
)
filename = "c_quizzes.pth"
model,
input,
ar_mask,
- acc_seq_logproba,
+ acc_seq_logprobas,
deterministic_synthesis=False,
):
if input.size(0) == 0:
all_n = torch.arange(t_next.size(0))
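+ # Accumulate the log-proba of the sampled token wherever the mask is 1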
- acc_seq_logproba += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
+ acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
model,
input,
ar_mask,
- seq_logproba,
+ seq_logprobas,
progress_bar_desc=None,
):
assert input.size() == ar_mask.size()
batches = zip(
input.split(self.batch_size),
ar_mask.split(self.batch_size),
- seq_logproba.split(self.batch_size),
+ seq_logprobas.split(self.batch_size),
)
if progress_bar_desc is not None:
t = model.training
model.eval()
- for input, ar_mask, seq_logproba in batches:
+ for input, ar_mask, seq_logprobas in batches:
one_batch_masked_inplace_autoregression(
model=model,
input=input,
ar_mask=ar_mask,
- acc_seq_logproba=seq_logproba,
+ acc_seq_logprobas=seq_logprobas,
deterministic_synthesis=False,
)
######################################################################
- def data_input(self, model, nb_samples):
- if len(model.c_quiz_bags) > 0:
- c_quizzes = torch.cat(model.c_quiz_bags, dim=0)
+ def data_input(self, nb_samples, c_quiz_bags):
+ if len(c_quiz_bags) > 0:
+ c_quizzes = torch.cat(c_quiz_bags, dim=0)
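+ # Use c_quizzes for at most half of the samples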
if c_quizzes.size(0) > nb_samples // 2:
i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
result = quizzes * (1 - ar_mask)
- seq_logproba = torch.zeros(quizzes.size(0), device=self.device)
+ seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
self.autoregression(
model=model,
input=result,
ar_mask=ar_mask,
- seq_logproba=seq_logproba,
+ seq_logprobas=seq_logprobas,
progress_bar_desc="accuracy",
)
correct = (result == quizzes).min(dim=1).values.long()
- result = result.to("cpu")
- correct = correct.to("cpu")
- seq_logproba = seq_logproba.to("cpu")
+ # result = result.to("cpu")
+ # correct = correct.to("cpu")
+ # seq_logprobas = seq_logprobas.to("cpu")
- return result, correct, seq_logproba
+ return result, correct, seq_logprobas
######################################################################
result[i], correct[i], _ = self.predict(
model=model, quizzes=input[i], struct=struct, mask=mask_generate
)
+
predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
None, :
]
def models_logprobas(
self,
- models_for_validation,
+ model,
c_quizzes,
struct,
mask_loss,
c_quizzes = self.problem.reconfigure(c_quizzes, struct)
- seq_logproba = torch.zeros(
+ seq_logprobas = torch.zeros(
c_quizzes.size(0),
- max([m.id for m in models_for_validation]) + 1,
device=device,
)
# c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise
# )
- for model in models_for_validation:
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, l in zip(
- c_quizzes.split(self.batch_size),
- seq_logproba.split(self.batch_size),
- ):
- input = input.to(device)
- quiz_mask_loss = self.make_quiz_mask(
- input, struct=struct, mask=mask_loss
- )
- output = model(mygpt.BracketedSequence(input)).x
- l[:, model.id] = (
- -F.cross_entropy(
- output.transpose(1, 2), input, reduction="none"
- )
- * quiz_mask_loss
- ).sum(dim=1)
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ for input, l in zip(
+ c_quizzes.split(self.batch_size),
+ seq_logprobas.split(self.batch_size),
+ ):
+ input = input.to(device)
+ quiz_mask_loss = self.make_quiz_mask(
+ input, struct=struct, mask=mask_loss
+ )
+ output = model(mygpt.BracketedSequence(input)).x
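+ # Sum the log-probas over the tokens selected by the mask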
+ l[...] = (
+ -F.cross_entropy(output.transpose(1, 2), input, reduction="none")
+ * quiz_mask_loss
+ ).sum(dim=1)
- model.train(t)
+ model.train(t)
- return seq_logproba.to("cpu")
+ return seq_logprobas.to("cpu")
######################################################################
def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None):
- seq_logproba = torch.zeros(nb, device=self.device)
+ seq_logprobas = torch.zeros(nb, device=self.device)
c_quizzes = None
model=model_for_generation,
input=c_quizzes,
ar_mask=self.make_quiz_mask(c_quizzes, s, m),
- seq_logproba=seq_logproba,
+ seq_logprobas=seq_logprobas,
)
model_for_generation.reset_transformations()