parser.add_argument("--test", type=str, default=None)
+parser.add_argument("--logit_std_max", type=float, default=-1, help="if > 0, cap the std of the readout logits at this value")
+
######################################################################
grids_tasks = ", ".join(
dropout=args.dropout,
).to(main_device)
+ class UpperBoundStd(nn.Module):
+ def __init__(self, std_max=1.0):
+ super().__init__()
+ self.std_max = std_max
+
+        def forward(self, x):
+            std = x.std(dim=-1, keepdim=True)
+            # clamp the divisor from below so the output std is
+            # min(std, std_max): rows whose std exceeds the bound are
+            # rescaled down, the others are left untouched (up to centering)
+            y = (x - x.mean(dim=-1, keepdim=True)) / (std / self.std_max).clamp(min=1.0)
+            return y
+
+ if args.logit_std_max > 0:
+ model.readout.f = nn.Sequential(
+ model.readout.f, UpperBoundStd(std_max=args.logit_std_max)
+ )
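+    # Hypothetical sanity check (values made up for illustration): with
+    # std_max=1.0, rows with std around 5.0 come out with std exactly 1.0,
+    # while rows already below the bound keep their std (up to centering):
+    #
+    #   m = UpperBoundStd(std_max=1.0)
+    #   x = torch.randn(4, 16) * 5.0
+    #   assert (m(x).std(dim=-1) <= 1.0 + 1e-4).all()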
+
model.id = k
model.train_c_quiz_bags = []
model.test_c_quiz_bags = []
######################################################################
+
if args.test == "entropy":
model = models[0]
model.to(main_device)
- log_string("starting testing entropy maximization")
-
- train_input = quiz_machine.generate_c_quizzes(
- 1000,
- model_for_generation=model,
- procedure=c_quizzes_procedure,
- )
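+    # NB: the optimizer is only referenced by the gradient-based variant
+    # kept commented out at the end of this branch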
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
- for n_epoch in range(10):
- nb_train_samples, acc_train_loss = 0, 0.0
+ log_string("starting testing entropy maximization")
- for input in train_input.split(args.batch_size):
- input = input.to(main_device)
- output = model(mygpt.BracketedSequence(input)).x
- loss = output.log_softmax(dim=1).mean()
+ for n_epoch in range(100):
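+        # generate a fresh batch of c-quizzes from the current weights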
+ input = quiz_machine.generate_c_quizzes(
+ 128,
+ model_for_generation=model,
+ procedure=c_quizzes_procedure,
+ )
- acc_train_loss += loss.item() * input.size(0)
- nb_train_samples += input.size(0)
+        filename = f"test_{n_epoch:04d}.png"
+
+        quiz_machine.problem.save_quizzes_as_image(
+            args.result_dir,
+            filename,
+            quizzes=input,
+        )
- model.optimizer.zero_grad()
- loss.backward()
- model.optimizer.step()
+ log_string(f"wrote {filename}")
- log_string(
- f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
- )
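+        # no gradient step: instead, nudge every weight with a small
+        # Gaussian random walk so that successive generations differ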
+ with torch.no_grad():
+ for p in model.parameters():
+ p += torch.randn(p.size(), device=p.device) * 1e-3
+
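+        # Alternative kept for reference: a gradient step maximizing the
+        # cross-entropy of the model on its own generations (the loss
+        # below is minus the cross-entropy, so minimizing it increases
+        # the prediction entropy)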
+ # nb_train_samples, acc_train_loss = 0, 0.0
+
+ # for k in range(1000 // args.batch_size):
+ # input = quiz_machine.generate_c_quizzes(
+ # args.batch_size,
+ # model_for_generation=model,
+ # procedure=[(("f_B", "f_A", "A", "B"), (1, 1, 1, 1), None)],
+ # )
+
+ # input = input.to(main_device)
+ # targets = input
+ # output = model(mygpt.BracketedSequence(input)).x
+ # loss = -F.cross_entropy(output.transpose(1, 2), targets)
+ # acc_train_loss += loss.item() * input.size(0)
+ # nb_train_samples += input.size(0)
+
+ # optimizer.zero_grad()
+ # loss.backward()
+ # optimizer.step()
+
+ # log_string(
+ # f"increase_entropy {n_epoch} entropy {acc_train_loss/nb_train_samples}"
+ # )
exit(0)