parser.add_argument("--reboot", action="store_true", default=False)
-parser.add_argument("--schedule_free", action="store_true", default=False)
-
# ----------------------------------
parser.add_argument("--model", type=str, default="37M")
def run_tests(model, quiz_machine, local_device=main_device):
with torch.autograd.no_grad():
model.to(local_device).eval()
- if args.schedule_free:
- model.optimizer.eval()
nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
model.to(local_device).train()
optimizer_to(model.optimizer, local_device)
- if args.schedule_free:
- model.optimizer.train()
-
nb_train_samples, acc_train_loss = 0, 0.0
full_input, _, full_mask_loss = quiz_machine.data_input(
m.weight.fill_(1.0)
def forward(self, bs):
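+        # Convenience path: a plain token tensor is wrapped in a BracketedSequence
+        # and the unwrapped .x tensor is returned.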
+ if torch.is_tensor(bs):
+ return self.forward(BracketedSequence(bs)).x
bs = self.embedding(bs)
bs = self.positional_encoding(bs)
bs = self.trunk(bs)
return bs
-def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
+def ae_batches(
+ quiz_machine,
+ nb,
+ data_structures,
+ local_device,
+ desc=None,
+ batch_size=args.batch_size,
+):
full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
nb, data_structures=data_structures
)
src = zip(
- full_input.split(args.batch_size),
- full_mask_generate.split(args.batch_size),
- full_mask_loss.split(args.batch_size),
+ full_input.split(batch_size),
+ full_mask_generate.split(batch_size),
+ full_mask_loss.split(batch_size),
)
    if desc is not None:
        src = tqdm.tqdm(
            src,
dynamic_ncols=True,
desc=desc,
- total=full_input.size(0) // args.batch_size,
+ total=full_input.size(0) // batch_size,
)
    for input, mask_generate, mask_loss in src:
        yield (
            input.to(local_device),
            mask_generate.to(local_device),
            mask_loss.to(local_device),
        )
-def degrade_input_inplace(input, mask_generate, pure_noise=False):
- if pure_noise:
- mask_diffusion_noise = torch.rand(
- mask_generate.size(), device=mask_generate.device
- ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
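+# Return one degraded copy of `input` per threshold in `ts`: positions selected by
+# `mask_generate` are replaced with uniform random colors with probability t, using
+# the same random draws for every threshold so the copies are nested degradations.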
+def degrade_input(input, mask_generate, *ts):
+ noise = torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
- mask_diffusion_noise = mask_diffusion_noise.long()
+ r = torch.rand(mask_generate.size(), device=mask_generate.device)
- input[...] = (
- mask_generate
- * mask_diffusion_noise
- * torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
- + (1 - mask_generate * mask_diffusion_noise) * input
- )
+ result = []
- else:
- model.eval()
- for it in range(torch.randint(5, (1,)).item()):
- logits = model(
- mygpt.BracketedSequence(
- torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
- )
- ).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
- input[...] = (1 - mask_generate) * input + mask_generate * dist.sample()
- model.train()
+ for t in ts:
+ mask_diffusion_noise = mask_generate * (r <= t).long()
+ x = (1 - mask_diffusion_noise) * input + mask_diffusion_noise * noise
+ result.append(x)
+
+ return result
+
+
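+# Token-wise cross-entropy for (N, T, C)-shaped logits (N batch, T time, C channel);
+# the loss of tokens whose mask is 0 is zeroed out before averaging.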
+def NTC_masked_cross_entropy(output, targets, mask):
+ loss_per_token = F.cross_entropy(output.transpose(1, 2), targets, reduction="none")
+ return (loss_per_token * mask).mean()
+
+
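+# Stack (N, T)-shaped tensors as channels of a single (N, T, C) tensor, broadcasting
+# each argument to the shape of the first one.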
+def NTC_channel_cat(*x):
+ return torch.cat([a.expand_as(x[0])[:, :, None] for a in x], dim=2)
+
+
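+# Generate the masked positions by iterated denoising: start from uniform random
+# colors and apply the model nb_iterations times, passing the number of remaining
+# iterations as an extra channel.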
+def ae_generate(model, input, mask_generate, n_epoch, nb_iterations):
+ noise = torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
+ input = (1 - mask_generate) * input + mask_generate * noise
+
+ for it in range(nb_iterations):
+ rho = input.new_full((input.size(0),), nb_iterations - 1 - it)
+ input_with_mask = NTC_channel_cat(input, mask_generate, rho[:, None])
+ logits = model(input_with_mask)
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ input = (1 - mask_generate) * input + mask_generate * dist.sample()
+
+ return input
def test_ae(local_device=main_device):
dropout=args.dropout,
).to(main_device)
- pure_noise = True
-
# quad_order, quad_generate, quad_noise, quad_loss
+
data_structures = [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
(("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
model.to(local_device).train()
optimizer_to(model.optimizer, local_device)
- if args.schedule_free:
- model.optimizer.train()
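+    # number of denoising steps, used both for the training noise levels and for generation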
+ nb_iterations = 5
for n_epoch in range(args.nb_epochs):
# ----------------------
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- targets = input.clone()
- degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
-
- output = model(
- mygpt.BracketedSequence(
- torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
- )
- ).x
-
- # for filename, quizzes in [
- # ("targets.png", targets),
- # ("input.png", input),
- # ("mask_generate.png", mask_generate),
- # ("mask_loss.png", mask_loss),
- # ]:
- # quiz_machine.problem.save_quizzes_as_image(
- # args.result_dir,
- # filename,
- # quizzes=quizzes,
- # )
- # time.sleep(10000)
-
- loss_per_token = F.cross_entropy(
- output.transpose(1, 2), targets, reduction="none"
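+            # Draw a random noise level rho per sample and build two degradations of the
+            # clean input: the less noisy one is the target, the noisier one is the model
+            # input, so the model learns a single denoising step.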
+ rho = torch.randint(nb_iterations, (input.size(0), 1), device=input.device)
+ targets, input = degrade_input(
+ input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
)
- loss = (loss_per_token * mask_loss).mean()
+ input_with_mask = NTC_channel_cat(input, mask_generate, rho)
+ output = model(input_with_mask)
+ loss = NTC_masked_cross_entropy(output, targets, mask_loss)
acc_train_loss += loss.item() * input.size(0)
nb_train_samples += input.size(0)
+
loss.backward()
if nb_train_samples % args.batch_size == 0:
local_device,
"test",
):
- targets = input.clone()
- degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
- output = model(
- mygpt.BracketedSequence(
- torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
- )
- ).x
- loss_per_token = F.cross_entropy(
- output.transpose(1, 2), targets, reduction="none"
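+            # Same one-step denoising objective as in training, evaluated on test quizzes.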
+ rho = torch.randint(
+ nb_iterations, (input.size(0), 1), device=input.device
+ )
+ targets, input = degrade_input(
+ input, mask_generate, rho / nb_iterations, (rho + 1) / nb_iterations
)
- loss = (loss_per_token * mask_loss).mean()
+ input_with_mask = NTC_channel_cat(input, mask_generate, rho)
+ output = model(input_with_mask)
+ loss = NTC_masked_cross_entropy(output, targets, mask_loss)
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
for ns, s in enumerate(data_structures):
quad_order, quad_generate, _, _ = s
- input, mask_generate, mask_loss = next(
- ae_batches(quiz_machine, 128, [s], local_device)
+ input, mask_generate, _ = next(
+ ae_batches(quiz_machine, 128, [s], local_device, batch_size=128)
)
targets = input.clone()
- degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
- result = input
-
- not_converged = torch.full(
- (result.size(0),), True, device=result.device
- )
-
- for it in range(100):
- pred_result = result.clone()
- logits = model(
- mygpt.BracketedSequence(
- torch.cat(
- [
- result[not_converged, :, None],
- mask_generate[not_converged, :, None],
- ],
- dim=2,
- )
- )
- ).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
- update = (1 - mask_generate[not_converged]) * input[
- not_converged
- ] + mask_generate[not_converged] * dist.sample()
- result[not_converged] = update
- not_converged = (pred_result != result).max(dim=1).values
- if not not_converged.any():
- log_string(f"diffusion_converged {it=}")
- break
-
- correct = (result == targets).min(dim=1).values.long()
- predicted_parts = input.new(input.size(0), 4)
-
- nb = 0
-
- predicted_parts = torch.tensor(quad_generate, device=result.device)[
- None, :
- ]
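+            # Generate the masked parts from pure noise and count a quiz as correct only
+            # when every token matches the clean targets.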
+ input = ae_generate(model, input, mask_generate, n_epoch, nb_iterations)
+ correct = (input == targets).min(dim=1).values.long()
+ predicted_parts = torch.tensor(quad_generate, device=input.device)
+ predicted_parts = predicted_parts[None, :].expand(input.size(0), -1)
solution_is_deterministic = predicted_parts.sum(dim=-1) == 1
correct = (2 * correct - 1) * (solution_is_deterministic).long()
-
nb_correct = (correct == 1).long().sum()
nb_total = (correct != 0).long().sum()
+ correct_parts = predicted_parts * correct[:, None]
log_string(
f"test_accuracy {n_epoch} model AE setup {ns} {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
)
- correct_parts = predicted_parts * correct[:, None]
- predicted_parts = predicted_parts.expand_as(correct_parts)
-
- filename = f"prediction_ae_{n_epoch:04d}_{ns}.png"
+ filename = f"prediction_ae_{n_epoch:04d}_structure_{ns}.png"
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
filename,
- quizzes=result,
+ quizzes=input,
predicted_parts=predicted_parts,
correct_parts=correct_parts,
)
- log_string(f"wrote {filename}")
+ log_string(f"wrote {filename}")
if args.test == "ae":
def compute_causal_attzero(t_q, t_k):
return t_q < t_k
- if args.schedule_free:
- import schedulefree
-
for k in range(args.nb_gpts):
log_string(f"creating model {k}")
model.train_c_quiz_bags = []
model.test_c_quiz_bags = []
- if args.schedule_free:
- model.optimizer = schedulefree.AdamWScheduleFree(
- model.parameters(), lr=args.learning_rate
- )
- else:
- model.optimizer = torch.optim.Adam(
- model.parameters(), lr=args.learning_rate
- )
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
model.test_accuracy = 0.0
model.gen_test_accuracy = 0.0