)
+def degrade_input_inplace(input, mask_generate, pure_noise=False):
+    """Degrade `input` in place at the positions selected by `mask_generate`.
+
+    If `pure_noise` is True, a per-sample random fraction of the masked
+    positions is overwritten with colors drawn uniformly from
+    [0, quiz_machine.problem.nb_colors). Otherwise the masked positions
+    are re-sampled from the model's own predictions for 0 to 4 rounds.
+
+    NOTE(review): relies on the module globals `model`, `mygpt` and
+    `quiz_machine` -- visible here only as a patch hunk, so confirm they
+    are in scope at the call sites; consider passing them explicitly.
+    """
+    if pure_noise:
+        # Per-sample threshold: each sequence in the batch corrupts its own
+        # random fraction of the masked positions (broadcast over dim 1).
+        mask_diffusion_noise = torch.rand(
+            mask_generate.size(), device=mask_generate.device
+        ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
+
+        mask_diffusion_noise = mask_diffusion_noise.long()
+
+        # Keep unmasked tokens unchanged; overwrite the selected masked
+        # tokens with uniform random colors. `[...]` writes in place.
+        input[...] = (
+            1 - mask_generate
+        ) * input + mask_generate * mask_diffusion_noise * torch.randint(
+            quiz_machine.problem.nb_colors, input.size(), device=input.device
+        )
+    else:
+        model.eval()
+        # torch.randint(5, (1,)) -> 0..4 rounds of model-driven re-sampling.
+        for it in range(torch.randint(5, (1,)).item()):
+            logits = model(
+                mygpt.BracketedSequence(
+                    torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+                )
+            ).x
+            dist = torch.distributions.categorical.Categorical(logits=logits)
+            # Replace only the masked tokens with samples from the model.
+            input[...] = (1 - mask_generate) * input + mask_generate * dist.sample()
+        model.train()
+
def test_ae(local_device=main_device):
model = MyAttentionAE(
vocabulary_size=vocabulary_size,
pure_noise = True
+ # quad_order, quad_generate, quad_noise, quad_loss
data_structures = [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
(("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- targets = input
-
- input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
-
- if pure_noise:
- mask_diffusion_noise = torch.rand(
- mask_generate.size(), device=mask_generate.device
- ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
-
- mask_diffusion_noise = mask_diffusion_noise.long()
-
- input = input + mask_generate * mask_diffusion_noise * torch.randint(
- quiz_machine.problem.nb_colors, input.size(), device=input.device
- )
- else:
- model.eval()
- for it in range(torch.randint(5, (1,)).item()):
- logits = model(
- mygpt.BracketedSequence(
- torch.cat(
- [input[:, :, None], mask_generate[:, :, None]], dim=2
- )
- )
- ).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
- input = (1 - mask_generate) * input + mask_generate * dist.sample()
- model.train()
+ targets = input.clone()
+ degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
output = model(
mygpt.BracketedSequence(
torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
)
).x
- loss = F.cross_entropy(output.transpose(1, 2), targets)
+ loss_per_token = F.cross_entropy(
+ output.transpose(1, 2), targets, reduction="none"
+ )
+ loss = (loss_per_token * mask_loss).mean()
acc_train_loss += loss.item() * input.size(0)
nb_train_samples += input.size(0)
loss.backward()
local_device,
"test",
):
- targets = input
-
- input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
-
- if pure_noise:
- mask_diffusion_noise = torch.rand(
- mask_generate.size(), device=mask_generate.device
- ) <= torch.rand(
- (mask_generate.size(0), 1), device=mask_generate.device
- )
-
- mask_diffusion_noise = mask_diffusion_noise.long()
-
- input = (
- input
- + mask_generate
- * mask_diffusion_noise
- * torch.randint(
- quiz_machine.problem.nb_colors,
- input.size(),
- device=input.device,
- )
- )
- else:
- for it in range(torch.randint(5, (1,)).item()):
- logits = model(
- mygpt.BracketedSequence(
- torch.cat(
- [input[:, None], mask_generate[:, None]], dim=1
- )
- )
- ).x
- dist = torch.distributions.categorical.Categorical(
- logits=logits
- )
- input = (
- 1 - mask_generate
- ) * input + mask_generate * dist.sample()
-
+ targets = input.clone()
+ degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
output = model(
mygpt.BracketedSequence(
torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
)
).x
- loss = F.cross_entropy(output.transpose(1, 2), targets)
+ loss_per_token = F.cross_entropy(
+ output.transpose(1, 2), targets, reduction="none"
+ )
+ loss = (loss_per_token * mask_loss).mean()
acc_test_loss += loss.item() * input.size(0)
nb_test_samples += input.size(0)
ae_batches(quiz_machine, 128, [s], local_device)
)
- targets = input
-
- input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
-
- if pure_noise:
- mask_diffusion_noise = torch.rand(
- mask_generate.size(), device=mask_generate.device
- ) <= torch.rand(
- (mask_generate.size(0), 1), device=mask_generate.device
- )
-
- mask_diffusion_noise = mask_diffusion_noise.long()
-
- input = (
- input
- + mask_generate
- * mask_diffusion_noise
- * torch.randint(
- quiz_machine.problem.nb_colors,
- input.size(),
- device=input.device,
- )
- )
- else:
- for it in range(torch.randint(5, (1,)).item()):
- logits = model(
- mygpt.BracketedSequence(
- torch.cat(
- [input[:, :, None], mask_generate[:, :, None]],
- dim=2,
- )
- )
- ).x
- dist = torch.distributions.categorical.Categorical(
- logits=logits
- )
- input = (
- 1 - mask_generate
- ) * input + mask_generate * dist.sample()
-
+ targets = input.clone()
+ degrade_input_inplace(input, mask_generate, pure_noise=pure_noise)
result = input
not_converged = torch.full(
torch.cat(
[
result[not_converged, :, None],
- mask_generate[:, :, None],
+ mask_generate[not_converged, :, None],
],
dim=2,
)