From 64906d6dc98df6f5d2f382127330077ef1dbedcf Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 29 Jul 2024 21:14:40 +0200 Subject: [PATCH] Update. --- main.py | 57 ++++++++++++++++++++++++-------------------------------- mygpt.py | 4 +++- 2 files changed, 27 insertions(+), 34 deletions(-) diff --git a/main.py b/main.py index 9c8e0bd..1cf31b3 100755 --- a/main.py +++ b/main.py @@ -324,9 +324,10 @@ elif args.problem == "grids": nb_threads=args.nb_threads, tasks=args.grids_science_tasks, ) - science_w_quizzes = science_problem.generate_w_quizzes(args.nb_test_samples) + science_w_quizzes = science_problem.generate_w_quizzes(100) + if not args.resume: - problem.save_some_examples(args.result_dir, "science_") + science_problem.save_some_examples(args.result_dir, "science_") else: @@ -454,7 +455,7 @@ def one_epoch(model, quiz_machine, local_device=main_device): def model_transformer_hot(model): # model.temperature = args.temperature_hot - model.set_noise_injection(5.0, ("ffw", args.nb_blocks // 2)) + model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2)) def model_transformer_cold(model): @@ -813,39 +814,34 @@ for k in range(args.nb_gpts): current_epoch = 0 if args.resume: - try: - for model in models: - filename = f"gpt_{model.id:03d}.pth" - - try: - d = torch.load(os.path.join(args.result_dir, filename)) - model.load_state_dict(d[0]) - model.main_test_accuracy = d[1] - log_string(f"successfully loaded {filename}") - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass + for model in models: + filename = f"gpt_{model.id:03d}.pth" try: - filename = "c_quizzes.pth" - quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename)) + d = torch.load(os.path.join(args.result_dir, filename)) + model.load_state_dict(d[0]) + model.main_test_accuracy = d[1] log_string(f"successfully loaded {filename}") except FileNotFoundError: log_string(f"cannot find {filename}") pass - try: - filename = "state.pth" - state = torch.load(os.path.join(args.result_dir, filename)) - log_string(f"successfully loaded {filename}") - current_epoch = state["current_epoch"] - except FileNotFoundError: - log_string(f"cannot find {filename}") - pass + try: + filename = "c_quizzes.pth" + quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass - except: - log_string(f"error when loading {filename}.") - exit(1) + try: + filename = "state.pth" + state = torch.load(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + current_epoch = state["current_epoch"] + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass ###################################################################### @@ -872,11 +868,6 @@ if args.dirty_debug: args.nb_new_c_quizzes_for_train = 100 args.nb_new_c_quizzes_for_test = 10 - def compute_valid_quizzes(token_logprobas): - l = token_logprobas.sum(dim=-1).sort(dim=-1).values - return torch.rand(l[:, 0].size(), device=l.device) < 0.5 - - ###################################################################### for n_epoch in range(current_epoch, args.nb_epochs): diff --git a/mygpt.py b/mygpt.py index 7c51bae..15ed80e 100755 --- a/mygpt.py +++ b/mygpt.py @@ -233,7 +233,9 @@ class NoiseInjector(nn.Module): def forward(self, x): if self.noise_std > 0: - x = x + torch.randn(x.size(), device=x.device) * self.noise_std + x = x * ( + 1 - 2 * (torch.rand(x.size(), device=x.device) < self.noise_std).long() + ) return x -- 2.20.1