)
+class MultiEmbedding(nn.Module):
+    # Embeds several categorical channels and sums the results; nb_values
+    # gives the number of values of each channel, e.g. (vocabulary_size, 2)
+    # for a token channel and a binary mask channel
+    def __init__(self, nb_values, dim):
+        super().__init__()
+        self.embeddings = nn.ModuleList([nn.Embedding(n, dim) for n in nb_values])
+
+    # x is of shape (N, T, nb_channels) and of type long; the result is the
+    # (N, T, dim) sum of the per-channel embeddings
+    def forward(self, x):
+        y = 0
+        for f, z in zip(self.embeddings, x.split(1, dim=2)):
+            y = y + f(z[:, :, 0])
+        return y
+
+
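+# Usage sketch (illustration only, the values here are arbitrary):
+#
+#   emb = MultiEmbedding(nb_values=(64, 2), dim=128)
+#   tokens = torch.randint(64, (4, 50))  # (N, T) symbol channel
+#   mask = torch.randint(2, (4, 50))  # (N, T) binary mask channel
+#   y = emb(torch.stack([tokens, mask], dim=2))  # -> (4, 50, 128)
+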
class MyAttentionAE(nn.Module):
    def __init__(
        self,
        assert dim_model % nb_heads == 0
-        self.embedding = nn.Sequential(
-            CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+        self.embedding = CacheWrapper(
+            nn.Sequential(
+                # channel 0 embeds the token, channel 1 the generation-mask bit
+                MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
+            ),
        )
-        self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
+        self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
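+        # unlike the trainable encoding, the sinusoidal one above has no
+        # parameters, and len_max = 1e5 is simply a large upper bound on the
+        # sequence length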
        trunk_blocks = []
    dropout=args.dropout,
).to(main_device)
+# pure_noise selects how the positions to generate are corrupted: with
+# uniformly random symbols (True) or with the model's own resampled
+# predictions (False)
+pure_noise = True
+
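+# each record appears to be (quad_order, mask_generate, mask_noise, mask_loss);
+# the loss mask now equals the generation mask instead of covering all four
+# fields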
data_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (1, 1, 1, 1)),
- (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 1, 1)),
- (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 1, 1, 1)),
- (("A", "f_A", "B", "f_B"), (1, 1, 1, 0), (0, 0, 0, 0), (1, 1, 1, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0)),
+ (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (0, 1, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 0, 0, 0)),
+ (("A", "f_A", "B", "f_B"), (1, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
]
model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        targets = input
- # mask_diffusion_noise = (mask_generate == 1) & (
- # torch.rand(mask_generate.size(), device=mask_generate.device)
- # <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
- # )
+ input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
- # mask_diffusion_noise = mask_diffusion_noise.long()
+ if pure_noise:
+ mask_diffusion_noise = torch.rand(
+ mask_generate.size(), device=mask_generate.device
+ ) <= torch.rand((mask_generate.size(0), 1), device=mask_generate.device)
- # input = (
- # 1 - mask_diffusion_noise
- # ) * input + mask_diffusion_noise * torch.randint(
- # quiz_machine.problem.nb_colors, input.size(), device=input.device
- # )
+ mask_diffusion_noise = mask_diffusion_noise.long()
- # ------------------------------
- input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
- model.eval()
- for it in range(torch.randint(5, (1,)).item()):
- logits = model(mygpt.BracketedSequence(input)).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
- input = (1 - mask_generate) * input + mask_generate * dist.sample()
- model.train()
- # -----------------------------
+ input = input + mask_generate * mask_diffusion_noise * torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
+ else:
+ model.eval()
+ for it in range(torch.randint(5, (1,)).item()):
+ logits = model(
+ mygpt.BracketedSequence(
+ torch.cat(
+ [input[:, :, None], mask_generate[:, :, None]], dim=2
+ )
+ )
+ ).x
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ input = (1 - mask_generate) * input + mask_generate * dist.sample()
+ model.train()
- output = model(mygpt.BracketedSequence(input)).x
+ output = model(
+ mygpt.BracketedSequence(
+ torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+ )
+ ).x
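+        # The (N, T) -> (N, T, 2) packing above recurs before every call of
+        # the model; a small helper would factor it out (sketch only, the
+        # name is hypothetical, not part of this codebase):
+        #
+        #   def with_mask_channel(tokens, mask):
+        #       return torch.cat([tokens[:, :, None], mask[:, :, None]], dim=2)
+
+        # note that the cross-entropy below is computed over all positions,
+        # kept and generated alike, since targets is the uncorrupted input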
        loss = F.cross_entropy(output.transpose(1, 2), targets)
        acc_train_loss += loss.item() * input.size(0)
        nb_train_samples += input.size(0)
            ):
                targets = input
- # mask_diffusion_noise = (mask_generate == 1) & (
- # torch.rand(mask_generate.size(), device=mask_generate.device)
- # <= torch.rand(
- # (mask_generate.size(0), 1), device=mask_generate.device
- # )
- # )
-
- # mask_diffusion_noise = mask_diffusion_noise.long()
-
- # input = (
- # 1 - mask_diffusion_noise
- # ) * input + mask_diffusion_noise * torch.randint(
- # quiz_machine.problem.nb_colors, input.size(), device=input.device
- # )
-
- # ------------------------------
input = (1 - mask_generate) * input # PARANOIAAAAAAAAAA
- for it in range(torch.randint(5, (1,)).item()):
- logits = model(mygpt.BracketedSequence(input)).x
- dist = torch.distributions.categorical.Categorical(logits=logits)
- input = (1 - mask_generate) * input + mask_generate * dist.sample()
- # -----------------------------
-
- output = model(mygpt.BracketedSequence(input)).x
+ if pure_noise:
+ mask_diffusion_noise = torch.rand(
+ mask_generate.size(), device=mask_generate.device
+ ) <= torch.rand(
+ (mask_generate.size(0), 1), device=mask_generate.device
+ )
+
+ mask_diffusion_noise = mask_diffusion_noise.long()
+
+ input = (
+ input
+ + mask_generate
+ * mask_diffusion_noise
+ * torch.randint(
+ quiz_machine.problem.nb_colors,
+ input.size(),
+ device=input.device,
+ )
+ )
+ else:
+ for it in range(torch.randint(5, (1,)).item()):
+ logits = model(
+ mygpt.BracketedSequence(
+ torch.cat(
+ [input[:, None], mask_generate[:, None]], dim=1
+ )
+ )
+ ).x
+ dist = torch.distributions.categorical.Categorical(
+ logits=logits
+ )
+ input = (
+ 1 - mask_generate
+ ) * input + mask_generate * dist.sample()
+
+ output = model(
+ mygpt.BracketedSequence(
+ torch.cat([input[:, :, None], mask_generate[:, :, None]], dim=2)
+ )
+ ).x
                loss = F.cross_entropy(output.transpose(1, 2), targets)
                acc_test_loss += loss.item() * input.size(0)
                nb_test_samples += input.size(0)
                input = (1 - mask_generate) * input  # PARANOIAAAAAAAAAA
-                result = (1 - mask_generate) * input
-                # + mask_generate * torch.randint(
-                #     quiz_machine.problem.nb_colors, input.size(), device=input.device
-                # )
+                # initialize the positions to generate as in training, either
+                # with pure noise or with a few rounds of the model's samples
+                if pure_noise:
+                    mask_diffusion_noise = torch.rand(
+                        mask_generate.size(), device=mask_generate.device
+                    ) <= torch.rand(
+                        (mask_generate.size(0), 1), device=mask_generate.device
+                    )
+
+                    mask_diffusion_noise = mask_diffusion_noise.long()
+
+                    input = (
+                        input
+                        + mask_generate
+                        * mask_diffusion_noise
+                        * torch.randint(
+                            quiz_machine.problem.nb_colors,
+                            input.size(),
+                            device=input.device,
+                        )
+                    )
+                else:
+                    for it in range(torch.randint(5, (1,)).item()):
+                        logits = model(
+                            mygpt.BracketedSequence(
+                                torch.cat(
+                                    [input[:, :, None], mask_generate[:, :, None]],
+                                    dim=2,
+                                )
+                            )
+                        ).x
+                        dist = torch.distributions.categorical.Categorical(
+                            logits=logits
+                        )
+                        input = (
+                            1 - mask_generate
+                        ) * input + mask_generate * dist.sample()
+                result = input
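+                # iterative refinement: resample the positions to generate
+                # until the predictions stop changing, for at most 100 rounds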
                not_converged = torch.full(
                    (result.size(0),), True, device=result.device
                )
                for it in range(100):
                    pred_result = result.clone()
-                    logits = model(mygpt.BracketedSequence(result[not_converged])).x
+                    logits = model(
+                        mygpt.BracketedSequence(
+                            torch.cat(
+                                [
+                                    result[not_converged, :, None],
+                                    # index the mask with not_converged too, so
+                                    # that the batch dimensions stay aligned
+                                    mask_generate[not_converged, :, None],
+                                ],
+                                dim=2,
+                            )
+                        )
+                    ).x
                    dist = torch.distributions.categorical.Categorical(logits=logits)
                    update = (1 - mask_generate[not_converged]) * input[
                        not_converged