# ----------------------------------
-parser.add_argument("--nb_gpts", type=int, default=5)
+parser.add_argument("--nb_models", type=int, default=5)
parser.add_argument("--min_succeed_to_validate", type=int, default=2)
# (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
]
+# quad_order, quad_generate, quad_noise, quad_loss
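+#
+# Each entry of data_structures below gives the order of the four quiz
+# parts in the token sequence, followed by three binary quadruples
+# selecting the part(s) to generate, the part(s) to noise, and the
+# part(s) on which the loss is computed.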
+
+data_structures = [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+ (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+]
+
######################################################################
trunk_blocks = []
for b in range(nb_blocks):
+ # if b == nb_blocks//2:
+ # trunk_blocks += [
+ # QKVAttention(
+ # dim_in=dim_model,
+ # dim_qk=dim_keys,
+ # dim_v=dim_model // nb_heads,
+ # nb_heads=nb_heads,
+ # attention_dropout=dropout,
+ # ),
+ # VaswaniPositionalEncoding(len_max=1e5)
+ # ]
+
trunk_blocks += [
WithResidual(
CacheWrapper(
mask_loss.to(local_device),
)
- # quiz_machine.problem.save_quizzes_as_image(
- # args.result_dir,
- # filename="a.png",
- # quizzes=a,
- # )
-
- # quiz_machine.problem.save_quizzes_as_image(
- # args.result_dir,
- # filename="b.png",
- # quizzes=b,
- # )
-
- # time.sleep(1000)
-
def NTC_masked_cross_entropy(output, targets, mask):
    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
    return (loss_per_token * mask).mean()


def deterministic(mask_generate):
    return (mask_generate.sum(dim=1) < mask_generate.size(1) // 2).long()
+# This function returns a tensor of the same shape as low, filled with
+# uniform random values in [0, 1], such that the values at the True
+# positions of low are all less than the values at the False positions.
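+#
+# For instance, with low = [[True, False, True]], a possible output is
+# [[0.21, 0.83, 0.07]]: both values at the True positions are below the
+# value at the False position.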
+
+
def prioritized_rand(low):
    x = torch.rand(low.size(), device=low.device).sort(dim=1, descending=True).values
    k = torch.rand(low.size(), device=low.device) + low.long()
    k = k.sort(dim=1).indices
    # the smallest values land on the True positions, the largest on the False ones
    y = x.new(x.size())
    y.scatter_(dim=1, index=k, src=x)
    return y
-def ae_generate(
- model, input, mask_generate, n_epoch, noise_proba, nb_iterations_max=50
-):
+def ae_generate(model, input, mask_generate, noise_proba, nb_iterations_max=50):
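+    # Fill the positions selected by mask_generate with random colors, then
+    # iteratively overwrite them with the model's prediction: deterministic
+    # quizzes are rewritten entirely at each step, the others only on a
+    # random subset biased toward the positions that still disagree with
+    # the prediction.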
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
- input = (1 - mask_generate) * input + mask_generate * noise
- proba_erased = noise_proba
+ input = (1 - mask_generate) * input + mask_generate * noise
d = deterministic(mask_generate)[:, None]
+
changed = True
for it in range(nb_iterations_max):
r = prioritized_rand(final != input)
- mask_erased = mask_generate * (r <= proba_erased).long()
+ mask_erased = mask_generate * (r <= noise_proba).long()
+
mask_to_change = d * mask_generate + (1 - d) * mask_erased
update = (1 - mask_to_change) * input + mask_to_change * final
return result
-def test_ae(local_device=main_device):
- model = MyAttentionAE(
- vocabulary_size=vocabulary_size,
- dim_model=args.dim_model,
- dim_keys=args.dim_keys,
- dim_hidden=args.dim_hidden,
- nb_heads=args.nb_heads,
- nb_blocks=args.nb_blocks,
- dropout=args.dropout,
- ).to(main_device)
-
- # quad_order, quad_generate, quad_noise, quad_loss
-
- data_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
- (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
- (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
- (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
- ]
-
- model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
-
- model.to(local_device).train()
- optimizer_to(model.optimizer, local_device)
-
- nb_iterations = 25
- probs_iterations = torch.arange(nb_iterations, device=main_device)
- probs_iterations = 0.1 ** (probs_iterations / nb_iterations)
- probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+######################################################################
- for n_epoch in range(args.nb_epochs):
- # ----------------------
- # Train
- model.train()
- nb_train_samples, acc_train_loss = 0, 0.0
+def run_ae_test(model, quiz_machine, n_epoch, local_device=main_device):
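+    # Compute the masked cross-entropy on the test set, then the generation
+    # accuracy (restricted to quizzes whose solution is a single
+    # deterministic part), and save an image of the predicted quizzes.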
+ with torch.autograd.no_grad():
+ model.eval().to(local_device)
- noise_proba = 0.05
+ nb_test_samples, acc_test_loss = 0, 0.0
for input, mask_generate, mask_loss in ae_batches(
quiz_machine,
- args.nb_train_samples,
+ args.nb_test_samples,
data_structures,
local_device,
- "training",
+ "test",
):
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.zero_grad()
-
d = deterministic(mask_generate)
p = probs_iterations.expand(input.size(0), -1)
dist = torch.distributions.categorical.Categorical(probs=p)
        N0 = dist.sample()
        N1 = N0 + 1
N0 = (1 - d) * N0
N1 = (1 - d) * N1 + d * nb_iterations
-
targets, input = degrade_input(
input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
)
-
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # for n in ["input", "targets"]:
- # filename = f"{n}.png"
- # quiz_machine.problem.save_quizzes_as_image(
- # args.result_dir,
- # filename,
- # quizzes=locals()[n],
- # )
- # log_string(f"wrote {filename}")
- # time.sleep(1000)
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
+ acc_test_loss += loss.item() * input.size(0)
+ nb_test_samples += input.size(0)
- loss.backward()
+ log_string(
+ f"test_loss {n_epoch} model {model.id} {acc_test_loss/nb_test_samples}"
+ )
- if nb_train_samples % args.batch_size == 0:
- model.optimizer.step()
+ # -------------------------------------------
+ # Test generation
- train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+ nb_correct, nb_total, record = 0, 0, []
- log_string(f"train_loss {n_epoch} model AE {acc_train_loss/nb_train_samples}")
+ for input, mask_generate, mask_loss in ae_batches(
+ quiz_machine,
+ args.nb_test_samples,
+ data_structures,
+ local_device,
+ "test",
+ ):
+ targets = input.clone()
+ result = ae_generate(
+ model, (1 - mask_generate) * input, mask_generate, noise_proba
+ )
+ correct = (result == targets).min(dim=1).values.long()
+ predicted_parts = mask_generate.reshape(mask_generate.size(0), 4, -1)[
+ :, :, 1
+ ]
+ solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
+ correct = (2 * correct - 1) * (solution_is_deterministic).long()
+ nb_correct += (correct == 1).long().sum()
+ nb_total += (correct != 0).long().sum()
+ correct_parts = predicted_parts * correct[:, None]
+ record.append((result, predicted_parts, correct_parts))
- # ----------------------
- # Test
+ log_string(
+ f"test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
+ )
- with torch.autograd.no_grad():
- model.eval()
+ model.test_accuracy = nb_correct / nb_total
- nb_test_samples, acc_test_loss = 0, 0.0
+ filename = f"prediction_ae_{n_epoch:04d}.png"
- for input, mask_generate, mask_loss in ae_batches(
- quiz_machine,
- args.nb_test_samples,
- data_structures,
- local_device,
- "test",
- ):
- d = deterministic(mask_generate)
- p = probs_iterations.expand(input.size(0), -1)
- dist = torch.distributions.categorical.Categorical(probs=p)
- N0 = dist.sample()
- N1 = N0 + 1
- N0 = (1 - d) * N0
- N1 = (1 - d) * N1 + d * nb_iterations
- targets, input = degrade_input(
- input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
- )
- input_with_mask = NTC_channel_cat(input, mask_generate)
- logits = model(input_with_mask)
- loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+ result, predicted_parts, correct_parts = (
+ torch.cat([x[i] for x in record]) for i in [0, 1, 2]
+ )
+
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=result,
+ predicted_parts=predicted_parts,
+ correct_parts=correct_parts,
+ )
- log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
+ log_string(f"wrote {filename}")
- # -------------------------------------------
- # Test generation
- for ns, s in enumerate(data_structures):
- quad_order, quad_generate, _, _ = s
+######################################################################
- input, mask_generate, _ = next(
- ae_batches(quiz_machine, 128, [s], local_device, batch_size=128)
- )
- targets = input.clone()
- input = ae_generate(
- model,
- input,
- mask_generate,
- n_epoch,
- noise_proba=noise_proba,
- )
+def one_ae_epoch(model, quiz_machine, n_epoch, local_device=main_device):
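+    # Train the model for one epoch on degraded quizzes, then run the
+    # evaluation pass.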
+ model.train().to(local_device)
- correct = (input == targets).min(dim=1).values.long()
- predicted_parts = torch.tensor(quad_generate, device=input.device)
- predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)
- solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
- correct = (2 * correct - 1) * (solution_is_deterministic).long()
- nb_correct = (correct == 1).long().sum()
- nb_total = (correct != 0).long().sum()
- correct_parts = predicted_parts * correct[:, None]
-
- log_string(
- f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
- )
+ nb_train_samples, acc_train_loss = 0, 0.0
- filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
+ for input, mask_generate, mask_loss in ae_batches(
+ quiz_machine,
+ args.nb_train_samples,
+ data_structures,
+ local_device,
+ "training",
+ ):
+ if nb_train_samples % args.batch_size == 0:
+ model.optimizer.zero_grad()
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=input,
- predicted_parts=predicted_parts,
- correct_parts=correct_parts,
- )
+ d = deterministic(mask_generate)
+ p = probs_iterations.expand(input.size(0), -1)
+ dist = torch.distributions.categorical.Categorical(probs=p)
+ N0 = dist.sample()
+ N1 = N0 + 1
+ N0 = (1 - d) * N0
+ N1 = (1 - d) * N1 + d * nb_iterations
- log_string(f"wrote {filename}")
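+        # N1 is the per-quiz number of degradation steps: N0 + 1 with N0
+        # drawn from probs_iterations, or the full nb_iterations for
+        # deterministic quizzes; the call below passes degrade_input the
+        # pair of degradation levels (0, N1).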
+ targets, input = degrade_input(
+ input, mask_generate, (0 * N1, N1), noise_proba=noise_proba
+ )
+ input_with_mask = NTC_channel_cat(input, mask_generate)
+ logits = model(input_with_mask)
+ loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
-if args.test == "ae":
- test_ae(local_device=main_device)
- exit(0)
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ model.optimizer.step()
+
+ log_string(
+ f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
+ )
+
+ run_ae_test(model, quiz_machine, n_epoch, local_device=local_device)
+
+
+######################################################################
+
+noise_proba = 0.05
+
+nb_iterations = 25
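+# probs_iterations[i] is the probability of drawing i degradation steps;
+# it decays geometrically, by an overall factor of 10, from i = 0 to
+# i = nb_iterations - 1.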
+probs_iterations = 0.1 ** torch.linspace(0, 1, nb_iterations, device=main_device)
+probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
+
+models = []
+
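+# Create args.nb_models auto-encoders, each with its own Adam optimizer
+# and a test_accuracy field used to rank the models across epochs.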
+for i in range(args.nb_models):
+ model = MyAttentionAE(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ dropout=args.dropout,
+ ).to(main_device)
+
+ model.id = i
+ model.test_accuracy = 0.0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ model.to(main_device).train()
+ optimizer_to(model.optimizer, main_device)
+
+ models.append(model)
+
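+# At every epoch, the len(gpus) models with the lowest test accuracy are
+# retrained, one per GPU, in parallel threads.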
+for n_epoch in range(args.nb_epochs):
+ ranked_models = sorted(models, key=lambda m: float(m.test_accuracy))
+ weakest_models = ranked_models[: len(gpus)]
+
+ threads = []
+
+ start_time = time.perf_counter()
+
+ for gpu, model in zip(gpus, weakest_models):
+ log_string(f"training model {model.id}")
+
+ t = threading.Thread(
+ target=one_ae_epoch, daemon=True, args=(model, quiz_machine, n_epoch, gpu)
+ )
+
+ threads.append(t)
+
+ t.start()
+
+ for t in threads:
+ t.join()
######################################################################
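# Returns True where attention must be zeroed, i.e. where the key lies at
# a later position than the query (causal masking).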
def compute_causal_attzero(t_q, t_k):
return t_q < t_k
- for k in range(args.nb_gpts):
+ for k in range(args.nb_models):
log_string(f"creating model {k}")
model = mygpt.MyGPT(
if args.dirty_debug:
args.accuracy_to_make_c_quizzes = 0.0
- args.nb_gpts = 2
+ args.nb_models = 2
args.nb_new_c_quizzes_for_train = 100
args.nb_new_c_quizzes_for_test = 10