sys.stdout.flush()
+log_string(f"cmd {' '.join(sys.argv)}")
+
for n in vars(args):
log_string(f"args.{n} {getattr(args, n)}")
return (output - targets).abs().sum() / masks.sum()
-def oneshot(gpt, learning_rate_scheduler, task):
+def oneshot(model, learning_rate_scheduler, task):
+ t = model.training
+ model.eval()
+ mazes = task.test_input[:48].clone()
+ mazes[:, task.height * task.width :] = 0
+ policies = task.test_policies[:48]
+ targets = maze.stationary_densities(
+ mazes[:, : task.height * task.width].view(-1, task.height, task.width),
+ policies.view(-1, 4, task.height, task.width),
+ ).flatten(-2)
+ output = eval_mygpt(model, mazes, prompt_len=task.height * task.width)
+ output = F.softmax(output, dim=2)
+ print(f"{output.size()=}")
+ proba_path = output[:, task.height * task.width :, 4].reshape(
+ -1, task.height, task.width
+ )
+ mazes = mazes[:, : task.height * task.width].reshape(-1, task.height, task.width)
+ targets = targets.reshape(-1, task.height, task.width)
+ paths = task.test_input[:48, task.height * task.width :].reshape(
+ -1, task.height, task.width
+ )
+ filename = "oneshot.png"
+ maze.save_image(
+ os.path.join(args.result_dir, filename),
+ mazes=mazes,
+ # target_paths=paths,
+ score_paths=proba_path,
+ score_truth=targets,
+ )
+ log_string(f"wrote {filename}")
+
+
+def oneshot_old(gpt, learning_rate_scheduler, task):
t = gpt.training
gpt.eval()
learning_rate_scheduler.reset()
for n_epoch in range(args.nb_epochs):
- learning_rate = learning_rate_scheduler.learning_rate()
+ learning_rate = learning_rate_scheduler.get_learning_rate()
+ log_string(f"learning_rate {n_epoch} {learning_rate}")
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
acc_train_loss, nb_train_samples = 0, 0
)
# -------------------
- mazes = task.test_input[:32, : task.height * task.width]
- policies = task.test_policies[:32]
+ mazes = task.test_input[:48, : task.height * task.width]
+ policies = task.test_policies[:48]
output_gpt = eval_mygpt(
gpt, mazes, mode=args.oneshot_input, prompt_len=task.height * task.width
)
class LearningRateScheduler:
- def learning_rate(self):
+ def get_learning_rate(self):
pass
def update(self, nb_finished_epochs, loss):
return vars(self)
def set_state(self, state):
- for k, v in state.item():
+ print(f"{state=}")
+ for k, v in state.items():
setattr(self, k, v)
self.nb_finished_epochs = 0
self.schedule = schedule
- def learning_rate(self):
+ def get_learning_rate(self):
return self.schedule[self.nb_finished_epochs]
+ def update(self, nb_finished_epochs, loss):
+ self.nb_finished_epochs = nb_finished_epochs
+
def reset(self):
self.nb_finished_epochs = 0
+ def get_state(self):
+ return {"nb_finished_epochs": self.nb_finished_epochs}
+
+
+class AutoScheduler(LearningRateScheduler):
+ def __init__(self, learning_rate_init, growth=1.0, degrowth=0.2):
+ self.learning_rate_init = learning_rate_init
+ self.learning_rate = learning_rate_init
+ self.growth = growth
+ self.degrowth = degrowth
+ self.pred_loss = None
+
+ def get_learning_rate(self):
+ return self.learning_rate
+
+ def update(self, nb_finished_epochs, loss):
+ if self.pred_loss is not None:
+ if loss >= self.pred_loss:
+ self.learning_rate *= self.degrowth
+ else:
+ self.learning_rate *= self.growth
+ self.pred_loss = loss
+
+ def reset(self):
+ self.learning_rate = self.learning_rate_init
+
+ def get_state(self):
+ return {
+ "learning_rate_init": self.learning_rate_init,
+ "pred_loss": self.pred_loss,
+ }
+
######################################################################
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
- input = self.test_input[:32]
+ input = self.test_input[:48]
result = input.clone()
ar_mask = result.new_zeros(result.size())
ar_mask[:, self.height * self.width :] = 1
q = torch.arange(d)[:, None]
k = torch.arange(d)[None, :]
s = args.maze_height * args.maze_width
- # return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
- return q < k
+ return torch.logical_and(q < k, torch.logical_or(q >= s, k >= s))
+ # return q < k
+
+def noncausal_prompt_oneshot_amm_generator(d):
+ q = torch.arange(d)[:, None]
+ k = torch.arange(d)[None, :]
+ s = args.maze_height * args.maze_width
+ return k >= s
+ # return q < k
-amm_generator = None
-if args.noncausal_prompt:
+if args.oneshot:
+ amm_generator = noncausal_prompt_oneshot_amm_generator
+elif args.noncausal_prompt:
amm_generator = noncausal_prompt_amm_generator
+else:
+ amm_generator = None
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
######################################################################
if args.learning_rate_schedule == "auto":
- pass
+ learning_rate_scheduler = AutoScheduler(args.learning_rate)
elif args.learning_rate_schedule == "cos":
schedule = {}
checkpoint = torch.load(checkpoint_name)
nb_epochs_finished = checkpoint["nb_epochs_finished"]
model.load_state_dict(checkpoint["model_state"])
+ learning_rate_scheduler.set_state(checkpoint["learning_rate_scheduler_state"])
torch.set_rng_state(checkpoint["rng_state"])
if torch.cuda.is_available():
torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])
except FileNotFoundError:
log_string("starting from scratch.")
- except:
- log_string("error when loading the checkpoint.")
- exit(1)
+ # except:
+ # log_string("error when loading the checkpoint.")
+ # exit(1)
+
+######################################################################
+
+if args.oneshot:
+ oneshot(model, learning_rate_scheduler, task)
+ exit(0)
######################################################################
learning_rate_scheduler.reset()
for n_epoch in range(nb_epochs_finished, args.nb_epochs):
- learning_rate = learning_rate_scheduler.learning_rate()
-
- log_string(f"learning_rate {learning_rate}")
+ learning_rate = learning_rate_scheduler.get_learning_rate()
+ log_string(f"learning_rate {n_epoch} {learning_rate}")
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
checkpoint = {
"nb_epochs_finished": n_epoch + 1,
"model_state": model.state_dict(),
+ "learning_rate_scheduler_state": learning_rate_scheduler.get_state(),
"rng_state": torch.get_rng_state(),
}
log_string(f"saved checkpoint {checkpoint_name}")
######################################################################
-
-if args.oneshot:
- oneshot(model, learning_rate_scheduler, task)
-
-######################################################################