self.check_structure(quizzes, struct)
return struct
- def inject_noise(self, quizzes, noise, struct, mask):
+ def inject_noise(self, quizzes, noise, struct, quad):
assert self.check_structure(quizzes, struct=struct)
S = self.height * self.width
- mask = torch.tensor(mask, device=quizzes.device)
+ mask = torch.tensor(quad, device=quizzes.device)
mask = mask[None, :, None].expand(1, 4, S + 1).clone()
mask[:, :, 0] = 0
mask = mask.reshape(1, -1).expand_as(quizzes)
).values
def make_quiz_mask(
- self, quizzes, struct=("A", "f_A", "B", "f_B"), mask=(0, 0, 0, 1)
+ self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1)
):
assert self.check_structure(quizzes, struct)
S = self.height * self.width
a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
- a[:, 0, :] = mask[0]
- a[:, 1, :] = mask[1]
- a[:, 2, :] = mask[2]
- a[:, 3, :] = mask[3]
+ a[:, 0, :] = quad[0]
+ a[:, 1, :] = quad[1]
+ a[:, 2, :] = quad[2]
+ a[:, 3, :] = quad[3]
return ar_mask
model,
solved_c_quizzes[:, model.id],
struct=("A", "f_A", "B", "f_B"),
- mask=(0, 0, 0, 1),
+ quad=(0, 0, 0, 1),
)
proba_own_solution[:, model.id] = model_proba_solutions(
######################################################################
+from mygpt import (
+ WithResidual,
+ CacheWrapper,
+ AddPositionalEncoding,
+ QKVAttention,
+ BracketedSequence,
+)
+
+
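+# A transformer encoder that maps a token sequence to per-token logits.
+# Despite the name there is no variational component; test_ae() below
+# trains it as a denoising autoencoder.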
+class MyAttentionVAE(nn.Module):
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.embedding = nn.Sequential(
+ CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+ )
+
+ self.positional_encoding = AddPositionalEncoding(len_max)
+
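+        # Trunk: nb_blocks blocks, each a pre-LayerNorm self-attention
+        # and a pre-LayerNorm two-layer MLP, both with residual connections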
+ trunk_blocks = []
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ ),
+ QKVAttention(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ),
+ ]
+
+ self.trunk = nn.Sequential(*trunk_blocks)
+
+ self.readout = CacheWrapper(
+ nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+ )
+
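+        # Initialization: small-variance embeddings, identity LayerNorm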
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, bs):
+ bs = self.embedding(bs)
+ bs = self.positional_encoding(bs)
+ bs = self.trunk(bs)
+ bs = self.readout(bs)
+ return bs
+
+
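+# Trains the model as a denoising autoencoder on quizzes: masks the
+# tokens designated by the loss mask, reconstructs the full sequence,
+# and saves a sample of reconstructions as an image at every epoch.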
+def test_ae(local_device=main_device):
+ model = MyAttentionVAE(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ dropout=args.dropout,
+    ).to(local_device)
+
+    if args.schedule_free:
+        # Schedule-free optimizers expose .train()/.eval(), which plain Adam
+        # does not (assumes the schedulefree package is imported)
+        model.optimizer = schedulefree.AdamWScheduleFree(
+            model.parameters(), lr=args.learning_rate
+        )
+    else:
+        model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    model.train()
+ optimizer_to(model.optimizer, local_device)
+
+ if args.schedule_free:
+ model.optimizer.train()
+
+ for n_epoch in range(args.nb_epochs):
+ # ----------------------
+ # Train
+
+ model.train()
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ full_input, full_mask_loss = quiz_machine.data_input(args.nb_train_samples)
+
+ src = zip(
+ full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+ )
+
+ for input, mask_loss in tqdm.tqdm(
+ src,
+ dynamic_ncols=True,
+ desc="training",
+ total=full_input.size(0) // args.batch_size,
+ ):
+ input = input.to(local_device)
+ mask_loss = mask_loss.to(local_device)
+
+ if nb_train_samples % args.batch_size == 0:
+ model.optimizer.zero_grad()
+
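+            # Denoising objective: tokens flagged by mask_loss are zeroed
+            # out in the input, and the full sequence is the target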
+ targets = input
+ input = (mask_loss == 0).long() * input
+            output = model(BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), targets)
+ acc_train_loss += loss.item() * input.size(0)
+ nb_train_samples += input.size(0)
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ model.optimizer.step()
+
+        train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+        log_string(
+            f"train_loss {n_epoch} model AE {acc_train_loss/nb_train_samples} perplexity {train_perplexity}"
+        )
+
+ # ----------------------
+ # Test
+
+ with torch.autograd.no_grad():
+ model.eval()
+
+ nb_test_samples, acc_test_loss = 0, 0.0
+
+ full_input, full_mask_loss = quiz_machine.data_input(args.nb_test_samples)
+
+ src = zip(
+ full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+ )
+
+ for input, mask_loss in tqdm.tqdm(
+ src,
+ dynamic_ncols=True,
+ desc="testing",
+ total=full_input.size(0) // args.batch_size,
+ ):
+ input = input.to(local_device)
+ mask_loss = mask_loss.to(local_device)
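+                # Mask the input as in training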
+ targets = input
+ input = (mask_loss == 0).long() * input
+                output = model(BracketedSequence(input)).x
+ loss = F.cross_entropy(output.transpose(1, 2), targets)
+ acc_test_loss += loss.item() * input.size(0)
+ nb_test_samples += input.size(0)
+
+ log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
+
+ input, mask_loss = quiz_machine.data_input(128)
+ input = input.to(local_device)
+ mask_loss = mask_loss.to(local_device)
+ targets = input
+ input = (mask_loss == 0).long() * input
+            logits = model(BracketedSequence(input)).x
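+            # Sample every token independently from its predicted
+            # categorical distribution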
+ dist = torch.distributions.categorical.Categorical(logits=logits)
+ result = dist.sample()
+ L = input.size(1) // 4
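+            # Restore the structure token that starts each of the four
+            # quadrants (slices of length L)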
+ result[:, 0 * L] = input[:, 0 * L]
+ result[:, 1 * L] = input[:, 1 * L]
+ result[:, 2 * L] = input[:, 2 * L]
+ result[:, 3 * L] = input[:, 3 * L]
+ filename = f"prediction_ae_{n_epoch:04d}.png"
+
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=result,
+ )
+
+ log_string(f"wrote {filename}")
+
+
+if args.test == "ae":
+ test_ae(local_device=main_device)
+ exit(0)
+
+######################################################################
+
+
def create_models():
models = []
procedure=c_quizzes_procedure,
)
+ filename = f"test_{n_epoch:04d}.png"
+
quiz_machine.problem.save_quizzes_as_image(
args.result_dir,
- f"test_{n_epoch:04d}.png",
+ filename,
quizzes=input,
)
m = max(nb_c_quizzes_per_model)
- if m >= args.nb_train_samples:
+ if m * args.c_quiz_multiplier >= args.nb_train_samples:
break
model = models[nb_c_quizzes_per_model.index(m)]
self.answer_len = None
self.prompt_noise = prompt_noise
- # struct, mask_generate, mask_noise, mask_loss
+ # struct, quad_generate, quad_noise, quad_loss
self.train_structures = [
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
(("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
######################################################################
- def data_input(self, nb_samples, c_quiz_bags, c_quiz_multiplier=1):
+ def data_input(self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1):
if len(c_quiz_bags) > 0:
c_quizzes = torch.cat(c_quiz_bags, dim=0)
quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
if self.prompt_noise > 0.0:
- for struct, _, mask_noise, mask_loss in self.train_structures:
+ for struct, _, quad_noise, quad_loss in self.train_structures:
i = self.problem.indices_select(quizzes=quizzes, struct=struct)
if i.any():
quizzes[i] = self.problem.inject_noise(
- quizzes[i], self.prompt_noise, struct=struct, mask=mask_noise
+ quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
)
quiz_mask_loss[i] = self.make_quiz_mask(
- quizzes=quizzes[i], struct=struct, mask=mask_loss
+ quizzes=quizzes[i], struct=struct, quad=quad_loss
)
return quizzes, quiz_mask_loss
######################################################################
- def make_quiz_mask(self, quizzes, struct, mask):
+ def make_quiz_mask(self, quizzes, struct, quad):
assert struct in [s for s, _, _, _ in self.train_structures]
- return self.problem.make_quiz_mask(quizzes, struct=struct, mask=mask)
+ return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
######################################################################
- def predict(self, model, quizzes, struct, mask):
+ def predict(self, model, quizzes, struct, quad):
quizzes = quizzes.to(self.device)
- ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, mask=mask)
+ ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad)
result = quizzes * (1 - ar_mask)
seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
nb = 0
# We consider all the configurations that we train for
- for struct, mask_generate, _, _ in self.test_structures:
+ for struct, quad_generate, _, _ in self.test_structures:
i = self.problem.indices_select(quizzes=input, struct=struct)
nb += i.long().sum()
result[i], correct[i], _ = self.predict(
- model=model, quizzes=input[i], struct=struct, mask=mask_generate
+ model=model, quizzes=input[i], struct=struct, quad=quad_generate
)
- predicted_parts[i] = torch.tensor(mask_generate, device=self.device)[
+ predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[
None, :
]
solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
model,
c_quizzes,
struct,
- mask_loss,
- mask_noise=None,
+ quad_loss,
+ quad_noise=None,
temperature=1.0,
device=None,
):
device=device,
)
- # if self.prompt_noise > 0.0 and mask_noise is not None:
+ # if self.prompt_noise > 0.0 and quad_noise is not None:
# c_quizzes = self.problem.inject_noise(
- # c_quizzes, self.prompt_noise, struct=struct, mask=mask_noise
+ # c_quizzes, self.prompt_noise, struct=struct, quad=quad_noise
# )
with torch.autograd.no_grad():
):
input = input.to(device)
quiz_mask_loss = self.make_quiz_mask(
- input, struct=struct, mask=mask_loss
+ input, struct=struct, quad=quad_loss
)
output = model(mygpt.BracketedSequence(input)).x / temperature
l[...] = (
c_quizzes = None
for n_step, setup in enumerate(procedure):
- s, m, mt = setup
+ struct, quad_generate, model_modifier = setup
if c_quizzes is None:
- c_quizzes = self.problem.create_empty_quizzes(nb, s)
+ c_quizzes = self.problem.create_empty_quizzes(nb, struct)
c_quizzes = c_quizzes.to(self.device)
- elif s != pred_s:
- c_quizzes = self.problem.reconfigure(c_quizzes, s)
- pred_s = s
+ elif struct != pred_struct:
+ c_quizzes = self.problem.reconfigure(c_quizzes, struct)
+ pred_struct = struct
- if mt is not None:
- mt(model_for_generation)
+ if model_modifier is not None:
+ model_modifier(model_for_generation)
self.autoregression(
model=model_for_generation,
input=c_quizzes,
- ar_mask=self.make_quiz_mask(c_quizzes, s, m),
+ ar_mask=self.make_quiz_mask(c_quizzes, struct, quad_generate),
seq_logprobas=seq_logprobas,
progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
)
if recorder is not None:
x = c_quizzes.clone()
- t = torch.tensor(m, device=x.device)[None, :].expand(x.size(0), -1)
+ t = torch.tensor(quad_generate, device=x.device)[None, :].expand(
+ x.size(0), -1
+ )
recorder.append(
self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
)