From ec3ff2d8fc427ff50f2af0407dc62e347865e23f Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Sat, 17 Aug 2024 22:49:53 +0200
Subject: [PATCH] Update.

---
 main.py         | 80 ++++++++++++++++++++++++++++++++++++-------------
 quiz_machine.py |  2 +-
 2 files changed, 60 insertions(+), 22 deletions(-)

diff --git a/main.py b/main.py
index 92bc05f..127b71b 100755
--- a/main.py
+++ b/main.py
@@ -115,6 +115,8 @@ parser.add_argument("--dirty_debug", action="store_true", default=False)
 
 parser.add_argument("--test", type=str, default=None)
 
+parser.add_argument("--logit_std_max", type=float, default=-1)
+
 ######################################################################
 
 grids_tasks = ", ".join(
@@ -820,6 +822,21 @@ for k in range(args.nb_gpts):
         dropout=args.dropout,
     ).to(main_device)
 
+    class UpperBoundStd(nn.Module):
+        def __init__(self, std_max=1.0):
+            super().__init__()
+            self.std_max = std_max
+
+        def forward(self, x):
+            std = x.std(dim=-1, keepdim=True)
+            y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
+            return y
+
+    if args.logit_std_max > 0:
+        model.readout.f = nn.Sequential(
+            model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
+        )
+
     model.id = k
     model.train_c_quiz_bags = []
     model.test_c_quiz_bags = []
@@ -1034,36 +1051,57 @@ def save_generated_c_quizzes(model, filename, nb=64):
 
 ######################################################################
 
+
 if args.test == "entropy":
     model = models[0]
     model.to(main_device)
 
-    log_string("starting testing entropy maximization")
-
-    train_input = quiz_machine.generate_c_quizzes(
-        1000,
-        model_for_generation=model,
-        procedure=c_quizzes_procedure,
-    )
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
 
-    for n_epoch in range(10):
-        nb_train_samples, acc_train_loss = 0, 0.0
+    log_string("starting testing entropy maximization")
 
-        for input in train_input.split(args.batch_size):
-            input = input.to(main_device)
-            output = model(mygpt.BracketedSequence(input)).x
-            loss = output.log_softmax(dim=1).mean()
+    for n_epoch in range(100):
+        input = quiz_machine.generate_c_quizzes(
+            128,
+            model_for_generation=model,
+            procedure=c_quizzes_procedure,
+        )
 
-            acc_train_loss += loss.item() * input.size(0)
-            nb_train_samples += input.size(0)
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            f"test_{n_epoch:04d}.png",
+            quizzes=input,
+        )
 
-            model.optimizer.zero_grad()
-            loss.backward()
-            model.optimizer.step()
+        log_string(f"wrote {filename}")
 
-        log_string(
-            f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
-        )
+        with torch.no_grad():
+            for p in model.parameters():
+                p += torch.randn(p.size(), device=p.device) * 1e-3
+
+    # nb_train_samples, acc_train_loss = 0, 0.0
+
+    # for k in range(1000 // args.batch_size):
+    #     input = quiz_machine.generate_c_quizzes(
+    #         args.batch_size,
+    #         model_for_generation=model,
+    #         procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)],
+    #     )
+
+    #     input = input.to(main_device)
+    #     targets = input
+    #     output = model(mygpt.BracketedSequence(input)).x
+    #     loss = -F.cross_entropy(output.transpose(1, 2), targets)
+    #     acc_train_loss += loss.item() * input.size(0)
+    #     nb_train_samples += input.size(0)
+
+    #     optimizer.zero_grad()
+    #     loss.backward()
+    #     optimizer.step()
+
+    #     log_string(
+    #         f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
+    #     )
 
     exit(0)
 
diff --git a/quiz_machine.py b/quiz_machine.py
index 98e0ea5..18136e8 100755
--- a/quiz_machine.py
+++ b/quiz_machine.py
@@ -355,7 +355,7 @@ class QuizMachine:
                 input=c_quizzes,
                 ar_mask=self.make_quiz_mask(c_quizzes, s, m),
                 seq_logprobas=seq_logprobas,
-                progress_bar_desc=f"autoregression {n_step}/{len(procedure)}",
+                progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
             )
 
             model_for_generation.reset_transformations()
-- 
2.39.5
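
For reference, the UpperBoundStd module from the main.py hunk can be exercised on its own. The sketch below copies the forward computation verbatim from the patch; the tensor shapes and the print harness are illustrative, not part of the commit. The clamp divides by min(std, std_max), so rows whose standard deviation is at most std_max come out with unit standard deviation, while higher-variance rows come out with std / std_max.

import torch
from torch import nn


class UpperBoundStd(nn.Module):
    # Verbatim from the patch: center the last dimension, then divide by
    # min(std, std_max).
    def __init__(self, std_max=1.0):
        super().__init__()
        self.std_max = std_max

    def forward(self, x):
        std = x.std(dim=-1, keepdim=True)
        y = (x - x.mean(dim=-1, keepdim=True)) / std.clamp(max=self.std_max)
        return y


bound = UpperBoundStd(std_max=1.0)

for scale in (0.1, 1.0, 10.0):
    logits = torch.randn(4, 100) * scale
    out_std = bound(logits).std(dim=-1).mean().item()
    print(f"input std ~{scale:5.1f} -> output std {out_std:.2f}")

Composing it with an existing readout follows the same pattern as the patch, e.g. readout = nn.Sequential(readout, UpperBoundStd(std_max=args.logit_std_max)).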
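
The reworked "entropy" test loop takes no gradient steps: the Adam optimizer it builds is referenced only from the commented-out block, and each epoch instead nudges every parameter in place with Gaussian noise after saving the generated quizzes. A minimal sketch of that update, factored into a helper (perturb_ is a hypothetical name, not in the patch):

import torch


def perturb_(model: torch.nn.Module, sigma: float = 1e-3) -> None:
    # Same update as the patch's inner loop: add i.i.d. Gaussian noise to
    # every parameter in place, outside of autograd.
    with torch.no_grad():
        for p in model.parameters():
            p += torch.randn(p.size(), device=p.device) * sigma

# usage: call perturb_(model) once per epoch, after writing the quiz images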