From e6eefefa4bafb35a5312b4d512c0fe03fd98ffc3 Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 24 Aug 2024 16:02:25 +0200
Subject: [PATCH] Update.

---
 main.py | 227 +++++++++++++++++++++++---------------------------------
 1 file changed, 93 insertions(+), 134 deletions(-)

diff --git a/main.py b/main.py
index 7ce9b03..d3d237e 100755
--- a/main.py
+++ b/main.py
@@ -67,8 +67,6 @@ parser.add_argument("--learning_rate", type=float, default=5e-4)
 
 parser.add_argument("--reboot", action="store_true", default=False)
 
-parser.add_argument("--schedule_free", action="store_true", default=False)
-
 # ----------------------------------
 
 parser.add_argument("--model", type=str, default="37M")
@@ -335,8 +333,6 @@ def optimizer_to(optim, device):
 def run_tests(model, quiz_machine, local_device=main_device):
     with torch.autograd.no_grad():
         model.to(local_device).eval()
-        if args.schedule_free:
-            model.optimizer.eval()
 
         nb_test_samples, acc_test_loss = 0, 0.0
         nb_samples_accumulated = 0
@@ -389,9 +385,6 @@ def one_epoch(model, quiz_machine, local_device=main_device):
     model.to(local_device).train()
     optimizer_to(model.optimizer, local_device)
 
-    if args.schedule_free:
-        model.optimizer.train()
-
     nb_train_samples, acc_train_loss = 0, 0.0
 
     full_input, _, full_mask_loss = quiz_machine.data_input(
@@ -829,6 +822,8 @@ class MyAttentionAE(nn.Module):
                 m.weight.fill_(1.0)
 
     def forward(self, bs):
+        if torch.is_tensor(bs):
+            return self.forward(BracketedSequence(bs)).x
         bs = self.embedding(bs)
         bs = self.positional_encoding(bs)
         bs = self.trunk(bs)
@@ -836,15 +831,22 @@ class MyAttentionAE(nn.Module):
         return bs
 
 
-def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
+def ae_batches(
+    quiz_machine,
+    nb,
+    data_structures,
+    local_device,
+    desc=None,
+    batch_size=args.batch_size,
+):
     full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
         nb, data_structures=data_structures
     )
 
     src = zip(
-        full_input.split(args.batch_size),
-        full_mask_generate.split(args.batch_size),
-        full_mask_loss.split(args.batch_size),
+        full_input.split(batch_size),
+        full_mask_generate.split(batch_size),
+        full_mask_loss.split(batch_size),
     )
 
     if desc is not None:
@@ -852,7 +854,7 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
             src,
             dynamic_ncols=True,
             desc=desc,
-            total=full_input.size(0) // args.batch_size,
+            total=full_input.size(0) // batch_size,
         )
 
     for input, mask_generate, mask_loss in src:
@@ -863,34 +865,60 @@ def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
         )
 
 
-def degrade_input_inplace(input, mask_generate, pure_noise=False):
-    if pure_noise:
-        mask_diffusion_noise = torch.rand(
-            mask_generate.size(), device=mask_generate.device
-        ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+def degrade_input(input, mask_generate, *ts):
+    noise = torch.randint(
+        quiz_machine.problem.nb_colors, input.size(), device=input.device
+    )
 
-        mask_diffusion_noise = mask_diffusion_noise.long()
+    r = torch.rand(mask_generate.size(), device=mask_generate.device)
 
-        input[...] = (
-            mask_generate
-            * mask_diffusion_noise
-            * torch.randint(
-                quiz_machine.problem.nb_colors, input.size(), device=input.device
-            )
-            + (1 - mask_generate * mask_diffusion_noise) * input
-        )
+    result = []
 
-    else:
-        model.eval()
-        for it in range(torch.randint(5, (1,)).item()):
-            logits = model(
-                mygpt.BracketedSequence(
-                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
-                )
-            ).x
-            dist = torch.distributions.categorical.Categorical(logits=logits)
-            input[...] = (1 - mask_generate) * input + mask_generate * dist.sample()
-        model.train()
+    for t in ts:
+        mask_diffusion_noise = mask_generate * (r <= t).long()
+        x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
+        result.append(x)
+
+    return result
+
+    # quiz_machine.problem.save_quizzes_as_image(
+    #     args.result_dir,
+    #     filename="a.png",
+    #     quizzes=a,
+    # )
+
+    # quiz_machine.problem.save_quizzes_as_image(
+    #     args.result_dir,
+    #     filename="b.png",
+    #     quizzes=b,
+    # )
+
+    # time.sleep(1000)
+
+
+def NTC_masked_cross_entropy(output, targets, mask):
+    loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
+    return (loss_per_token * mask).mean()
+
+
+def NTC_channel_cat(*x):
+    return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
+
+
+def ae_generate(model, input, mask_generate, n_epoch, nb_iterations):
+    noise = torch.randint(
+        quiz_machine.problem.nb_colors, input.size(), device=input.device
+    )
+    input = (1 - mask_generate) * input + mask_generate * noise
+
+    for it in range(nb_iterations):
+        rho = input.new_full((input.size(0),), nb_iterations - 1 - it)
+        input_with_mask = NTC_channel_cat(input, mask_generate, rho[:, None])
+        logits = model(input_with_mask)
+        dist = torch.distributions.categorical.Categorical(logits=logits)
+        input = (1 - mask_generate) * input + mask_generate * dist.sample()
+
+    return input
 
 
 def test_ae(local_device=main_device):
@@ -904,9 +932,8 @@ def test_ae(local_device=main_device):
         dropout=args.dropout,
     ).to(main_device)
 
-    pure_noise = True
-
     # quad_order, quad_generate, quad_noise, quad_loss
+
     data_structures = [
         (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
         (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
@@ -920,8 +947,7 @@ def test_ae(local_device=main_device):
     model.to(local_device).train()
     optimizer_to(model.optimizer, local_device)
 
-    if args.schedule_free:
-        model.optimizer.train()
+    nb_iterations = 5
 
     for n_epoch in range(args.nb_epochs):
         # ----------------------
@@ -940,34 +966,16 @@ def test_ae(local_device=main_device):
             if nb_train_samples % args.batch_size == 0:
                 model.optimizer.zero_grad()
 
-            targets = input.clone()
-            degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
-
-            output = model(
-                mygpt.BracketedSequence(
-                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
-                )
-            ).x
-
-            # for filename, quizzes in [
-            #     ("targets.png", targets),
-            #     ("input.png", input),
-            #     ("mask_generate.png", mask_generate),
-            #     ("mask_loss.png", mask_loss),
-            # ]:
-            #     quiz_machine.problem.save_quizzes_as_image(
-            #         args.result_dir,
-            #         filename,
-            #         quizzes=quizzes,
-            #     )
-
-            # time.sleep(10000)
-
-            loss_per_token = F.cross_entropy(
-                output.transpose(1, 2), targets, reduction="none"
+            rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device)
+            targets, input = degrade_input(
+                input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
             )
-            loss = (loss_per_token * mask_loss).mean()
+            input_with_mask = NTC_channel_cat(input, mask_generate, rho)
+            output = model(input_with_mask)
+            loss = NTC_masked_cross_entropy(output, targets, mask_loss)
             acc_train_loss += loss.item() * input.size(0)
             nb_train_samples += input.size(0)
+
             loss.backward()
 
             if nb_train_samples % args.batch_size == 0:
@@ -992,17 +1000,15 @@ def test_ae(local_device=main_device):
                 local_device,
                 "test",
             ):
-                targets = input.clone()
-                degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
-                output = model(
-                    mygpt.BracketedSequence(
-                        torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
-                    )
-                ).x
-                loss_per_token = F.cross_entropy(
-                    output.transpose(1, 2), targets, reduction="none"
+                rho = torch.randint(
+                    nb_iterations, (input.size(0), 1), device=input.device
+                )
+                targets, input = degrade_input(
+                    input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
                 )
-                loss = (loss_per_token * mask_loss).mean()
+                input_with_mask = NTC_channel_cat(input, mask_generate, rho)
+                output = model(input_with_mask)
+                loss = NTC_masked_cross_entropy(output, targets, mask_loss)
                 acc_test_loss += loss.item() * input.size(0)
                 nb_test_samples += input.size(0)
 
@@ -1014,73 +1020,36 @@ def test_ae(local_device=main_device):
             for ns, s in enumerate(data_structures):
                 quad_order, quad_generate, _, _ = s
 
-                input, mask_generate, mask_loss = next(
-                    ae_batches(quiz_machine, 128, [s], local_device)
+                input, mask_generate, _ = next(
+                    ae_batches(quiz_machine, 128, [s], local_device, batch_size=128)
                 )
 
                 targets = input.clone()
-                degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
-                result = input
-
-                not_converged = torch.full(
-                    (result.size(0),), True, device=result.device
-                )
-
-                for it in range(100):
-                    pred_result = result.clone()
-                    logits = model(
-                        mygpt.BracketedSequence(
-                            torch.cat(
-                                [
-                                    result[not_converged, :, None],
-                                    mask_generate[not_converged, :, None],
-                                ],
-                                dim=2,
-                            )
-                        )
-                    ).x
-                    dist = torch.distributions.categorical.Categorical(logits=logits)
-                    update = (1 - mask_generate[not_converged]) * input[
-                        not_converged
-                    ] + mask_generate[not_converged] * dist.sample()
-                    result[not_converged] = update
-                    not_converged = (pred_result != result).max(dim=1).values
-                    if not not_converged.any():
-                        log_string(f"diffusion_converged {it=}")
-                        break
-
-                correct = (result == targets).min(dim=1).values.long()
-                predicted_parts = input.new(input.size(0), 4)
-
-                nb = 0
-
-                predicted_parts = torch.tensor(quad_generate, device=result.device)[
-                    None, :
-                ]
+                input = ae_generate(model, input, mask_generate, n_epoch, nb_iterations)
+                correct = (input == targets).min(dim=1).values.long()
+                predicted_parts = torch.tensor(quad_generate, device=input.device)
+                predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)
                 solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
                 correct = (2 * correct - 1) * (solution_is_deterministic).long()
-
                 nb_correct = (correct == 1).long().sum()
                 nb_total = (correct != 0).long().sum()
+                correct_parts = predicted_parts * correct[:, None]
 
                 log_string(
                     f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
                 )
 
-                correct_parts = predicted_parts * correct[:, None]
-                predicted_parts = predicted_parts.expand_as(correct_parts)
-
-                filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
+                filename = f"prediction_ae_{n_epoch:04d}_structure_{ns}.png"
 
                 quiz_machine.problem.save_quizzes_as_image(
                     args.result_dir,
                     filename,
-                    quizzes=result,
+                    quizzes=input,
                     predicted_parts=predicted_parts,
                     correct_parts=correct_parts,
                 )
 
-            log_string(f"wrote {filename}")
+                log_string(f"wrote {filename}")
 
 
 if args.test == "ae":
@@ -1096,9 +1065,6 @@ def create_models():
     def compute_causal_attzero(t_q, t_k):
         return t_q < t_k
 
-    if args.schedule_free:
-        import schedulefree
-
     for k in range(args.nb_gpts):
         log_string(f"creating model {k}")
 
@@ -1132,14 +1098,7 @@ def create_models():
         model.train_c_quiz_bags = []
         model.test_c_quiz_bags = []
 
-        if args.schedule_free:
-            model.optimizer = schedulefree.AdamWScheduleFree(
-                model.parameters(), lr=args.learning_rate
-            )
-        else:
-            model.optimizer = torch.optim.Adam(
-                model.parameters(), lr=args.learning_rate
-            )
+        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
         model.test_accuracy = 0.0
         model.gen_test_accuracy = 0.0
-- 
2.39.5
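
For reference, the hunks above replace the old "resample until the output stops
changing" sampler with a fixed-schedule masked-diffusion scheme: training picks
a random step rho, corrupts a fraction (rho + 1) / nb_iterations of the
maskable tokens, and learns to predict the rho / nb_iterations version, while
generation starts from pure noise and denoises over nb_iterations steps,
feeding the model the current noise level as a third channel. The sketch below
is a minimal self-contained rendering of that scheme, not code from the patch:
NB_COLORS stands in for quiz_machine.problem.nb_colors, and the uniform-logits
dummy model in the smoke test is purely illustrative.

    import torch

    NB_COLORS = 10  # stand-in for quiz_machine.problem.nb_colors


    def degrade_input(input, mask_generate, *ts):
        # Draw one noise tensor and one random field r; for each threshold t,
        # replace the maskable tokens where r <= t by noise. Reusing r and
        # noise across thresholds makes the outputs nested corruptions of the
        # same clean sequence.
        noise = torch.randint(NB_COLORS, input.size(), device=input.device)
        r = torch.rand(mask_generate.size(), device=mask_generate.device)
        result = []
        for t in ts:
            m = mask_generate * (r <= t).long()
            result.append((1 - m) * input + m * noise)
        return result


    def NTC_channel_cat(*x):
        # Stack N x T tensors as the channels of an N x T x C tensor.
        return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)


    @torch.no_grad()
    def ae_generate(model, input, mask_generate, nb_iterations):
        # Start from pure noise on the generated positions, then denoise over
        # nb_iterations steps, telling the model the current noise level rho.
        noise = torch.randint(NB_COLORS, input.size(), device=input.device)
        input = (1 - mask_generate) * input + mask_generate * noise
        for it in range(nb_iterations):
            rho = input.new_full((input.size(0),), nb_iterations - 1 - it)
            logits = model(NTC_channel_cat(input, mask_generate, rho[:, None]))
            dist = torch.distributions.Categorical(logits=logits)
            input = (1 - mask_generate) * input + mask_generate * dist.sample()
        return input


    if __name__ == "__main__":
        # Smoke test with a dummy model that outputs uniform logits.
        model = lambda x: torch.zeros(x.size(0), x.size(1), NB_COLORS)
        input = torch.randint(NB_COLORS, (4, 20))
        mask_generate = torch.zeros(4, 20, dtype=torch.long)
        mask_generate[:, 10:] = 1

        nb_iterations = 5
        rho = torch.randint(nb_iterations, (4, 1))
        targets, degraded = degrade_input(
            input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
        )
        print(ae_generate(model, degraded, mask_generate, nb_iterations))

One design point worth noting: because degrade_input reuses the same r and
noise for both thresholds, the pair (targets, degraded) differ only on the
tokens whose r falls between rho / nb_iterations and (rho + 1) / nb_iterations,
so each training step teaches exactly one denoising step of the schedule.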