return (output - targets).abs().sum() / masks.sum()
-def oneshot(gpt, task):
+def oneshot(gpt, learning_rate_scheduler, task):
t = gpt.training
gpt.eval()
nn.Linear(args.dim_model, dim_out),
).to(device)
+ learning_rate_scheduler.reset()
+
for n_epoch in range(args.nb_epochs):
- learning_rate = learning_rate_schedule[n_epoch]
+ learning_rate = learning_rate_scheduler.learning_rate()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
acc_train_loss, nb_train_samples = 0, 0
loss.backward()
optimizer.step()
+ learning_rate_scheduler.update(n_epoch + 1, acc_train_loss)
+
acc_test_loss, nb_test_samples = 0, 0
for mazes, policies in task.policy_batches(split="test"):
output_gpt = eval_mygpt(
######################################################################
+class LearningRateScheduler:
+ def learning_rate(self):
+ pass
+
+ def update(self, nb_finished_epochs, loss):
+ pass
+
+ def reset(self):
+ pass
+
+ def get_state(self):
+ return vars(self)
+
+ def set_state(self, state):
+ for k, v in state.item():
+ setattr(self, k, v)
+
+
+class StepWiseScheduler(LearningRateScheduler):
+ def __init__(self, schedule):
+ self.nb_finished_epochs = 0
+ self.schedule = schedule
+
+ def learning_rate(self):
+ return self.schedule[self.nb_finished_epochs]
+
+ def reset(self):
+ self.nb_finished_epochs = 0
+
+
+######################################################################
+
+
class Task:
def batches(self, split="train", nb_to_use=-1, desc=None):
pass
######################################################################
+if args.learning_rate_schedule == "auto":
+ pass
+
+elif args.learning_rate_schedule == "cos":
+ schedule = {}
+ for n_epoch in range(args.nb_epochs):
+ u = n_epoch / args.nb_epochs * math.pi
+ schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
+ learning_rate_scheduler = StepWiseScheduler(schedule)
+ log_string(f"learning_rate_schedule {schedule}")
+
+else:
+ u = {
+ int(k): float(v)
+ for k, v in [
+ tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
+ ]
+ }
+
+ schedule = {}
+ learning_rate = args.learning_rate
+ for n_epoch in range(args.nb_epochs):
+ if n_epoch in u:
+ learning_rate = u[n_epoch]
+ schedule[n_epoch] = learning_rate
+ learning_rate_scheduler = StepWiseScheduler(schedule)
+ log_string(f"learning_rate_schedule {schedule}")
+
+######################################################################
+
nb_epochs_finished = 0
if args.no_checkpoint:
##############################
-if args.learning_rate_schedule == "cos":
- learning_rate_schedule = {}
- for n_epoch in range(args.nb_epochs):
- u = n_epoch / args.nb_epochs * math.pi
- learning_rate_schedule[n_epoch] = args.learning_rate * 0.5 * (1 + math.cos(u))
-else:
- u = {
- int(k): float(v)
- for k, v in [
- tuple(x.split(":")) for x in args.learning_rate_schedule.split(",")
- ]
- }
-
- learning_rate_schedule = {}
- learning_rate = args.learning_rate
- for n_epoch in range(args.nb_epochs):
- if n_epoch in u:
- learning_rate = u[n_epoch]
- learning_rate_schedule[n_epoch] = learning_rate
-
-log_string(f"learning_rate_schedule {learning_rate_schedule}")
-
-##############################
-
if nb_epochs_finished >= args.nb_epochs:
n_epoch = nb_epochs_finished
train_perplexity = compute_perplexity(
##############################
+learning_rate_scheduler.reset()
+
for n_epoch in range(nb_epochs_finished, args.nb_epochs):
- learning_rate = learning_rate_schedule[n_epoch]
+ learning_rate = learning_rate_scheduler.learning_rate()
log_string(f"learning_rate {learning_rate}")
loss.backward()
optimizer.step()
+ learning_rate_scheduler.update(n_epoch + 1, acc_train_loss)
+
train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
test_perplexity = compute_perplexity(
model, task, prompt_len=task.height * task.width, split="test"
######################################################################
if args.oneshot:
- oneshot(model, task)
+ oneshot(model, learning_rate_scheduler, task)
######################################################################