nb_train_samples, acc_train_loss = 0, 0.0
- hard_w_quizzes = []
-
full_input, full_mask_loss = quiz_machine.data_input(
args.nb_train_samples, model.train_c_quiz_bags
)
run_tests(model, quiz_machine)
- # threshold = torch.cat([l for _, l in hard_w_quizzes], dim=0).sort().values
- # threshold = threshold[threshold.size(0) // 2]
-
- # model.hard_w_quizzes = torch.cat(
- # [x[l >= threshold] for x, l in hard_w_quizzes], dim=0
- # )
-
model.to(main_device)
optimizer_to(model.optimizer, main_device)
######################################################################
-def model_transformer_hot(model):
+def model_modifier_hot(model):
model.temperature = args.temperature_hot
# model.set_noise_injection(1.0, ("ffw", args.nb_blocks // 2))
-def model_transformer_cold(model):
+def model_modifier_cold(model):
model.temperature = args.temperature_cold
# pass
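+# Each entry is presumably (grid order, mask of the parts to generate, model
+# modifier applied during generation): sample f_B at high temperature, complete
+# the rest of the quiz at low temperature, then re-derive f_B and B as checks.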
c_quizzes_procedure = [
- (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
+ (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_modifier_hot),
+ (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_modifier_cold),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_modifier_cold),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_modifier_cold),
]
######################################################################
def save_additional_results(model, models):
- # Save generated quizzes with the successive steps
+ # Save generated quizzes with the successive generation steps
recorder = []
)
+######################################################################
+
+from mygpt import (
+ WithResidual,
+ CacheWrapper,
+ AddPositionalEncoding,
+ QKVAttention,
+ BracketedSequence,
+)
+
+
+class Thinker(nn.Module):
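+    """Two-trunk model: the bottom trunk reads (A, f_A) together with f_len
+    learned "function" tokens, the top trunk applies those tokens to B, and
+    the readout predicts f_B."""
+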
+ def __init__(
+ self,
+ vocabulary_size,
+ dim_model,
+ dim_keys,
+ dim_hidden,
+ nb_heads,
+ nb_blocks,
+ f_len,
+ dropout=0.0,
+ len_max=1e5,
+ ):
+ super().__init__()
+
+ assert dim_model % nb_heads == 0
+
+ self.embedding = nn.Sequential(
+ CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
+ AddPositionalEncoding(len_max),
+ )
+
+ def trunk(depth):
+ trunk_blocks = []
+
+            # stack `depth` residual attention + feed-forward blocks
+            for b in range(depth):
+ trunk_blocks += [
+ WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ ),
+ QKVAttention(
+ dim_in=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithResidual(
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ),
+ ]
+
+ return nn.Sequential(*trunk_blocks)
+
+ self.bottom_trunk = trunk(nb_blocks // 2)
+
+ self.top_trunk = trunk(nb_blocks // 2)
+
+ self.readout = CacheWrapper(
+ nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+ )
+
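+        # learned tokens appended to (A, f_A); after the bottom trunk they
+        # carry a fixed-size representation of the A -> f_A transformation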
+        self.f_len = f_len
+        self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
+
+ with torch.no_grad():
+ for m in self.modules():
+ if isinstance(m, nn.Embedding):
+ m.weight.normal_(mean=0, std=2e-2)
+ elif isinstance(m, nn.LayerNorm):
+ m.bias.zero_()
+ m.weight.fill_(1.0)
+
+ def forward(self, bs):
+ for m in self.modules():
+ m.loss = 0
+
+ L = bs.x.size(1) // 3
+
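+        # the input is the concatenation A | f_A | B, each part of length L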
+ bs = self.embedding(bs)
+ A_fA = BracketedSequence(bs.x[:, : 2 * L])
+ B = BracketedSequence(bs.x[:, -L:])
+
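+        # bottom trunk: encode (A, f_A) together with the function tokens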
+ bs = BracketedSequence(
+ torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
+ )
+ bs = self.bottom_trunk(bs)
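+        # top trunk: keep only the function tokens, prepend them to B, and
+        # predict f_B from the positions of B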
+        bs = BracketedSequence(torch.cat([bs.x[:, -self.f_len :, :], B.x], dim=1))
+ bs = self.top_trunk(bs)
+        bs = BracketedSequence(bs.x[:, self.f_len :, :])
+ bs = self.readout(bs)
+
+ for m in self.modules():
+ if m is not self:
+ self.loss += m.loss
+
+ return bs
+
+
+if args.test == "func":
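+    # quick experiment: train a Thinker to predict f_B from (A, f_A, B) on
+    # world quizzes generated by the problem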
+ train_input = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+ test_input = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+
+ L = train_input.size(1) // 4
+ f_len = 25
+
+ model = Thinker(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+        f_len=f_len,
+ dropout=args.dropout,
+ ).to(main_device)
+
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ for n_epoch in range(args.nb_epochs):
+ model.train()
+
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ for input in tqdm.tqdm(
+ train_input.split(args.batch_size),
+ dynamic_ncols=True,
+ desc="training",
+ total=train_input.size(0) // args.batch_size,
+ ):
+ input = input.to(main_device)
+
+ if nb_train_samples % args.batch_size == 0:
+ model.optimizer.zero_grad()
+
+ output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+ targets = input[:, 3 * L :]
+ loss = F.cross_entropy(output.transpose(1, 2), targets)
+ acc_train_loss += loss.item() * input.size(0)
+
+ nb_train_samples += input.size(0)
+
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ model.optimizer.step()
+
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+ log_string(f"train_perplexity {n_epoch} model thinker {train_perplexity}")
+
+ with torch.autograd.no_grad():
+ model.eval()
+
+ nb_test_samples, acc_test_loss = 0, 0.0
+
+ for input in tqdm.tqdm(
+ test_input.split(args.batch_size),
+ dynamic_ncols=True,
+ desc="testing",
+ total=test_input.size(0) // args.batch_size,
+ ):
+ input = input.to(main_device)
+
+ output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+ targets = input[:, 3 * L :]
+ loss = F.cross_entropy(output.transpose(1, 2), targets)
+ acc_test_loss += loss.item() * input.size(0)
+
+ nb_test_samples += input.size(0)
+
+ test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
+
+ log_string(f"test_perplexity {n_epoch} model thinker {test_perplexity}")
+
+ input = test_input[:128].clone().to(main_device)
+
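+        # sample a candidate f_B for the first 128 test quizzes from the
+        # predicted distribution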
+ output = model(mygpt.BracketedSequence(input[:, : 3 * L])).x
+ dist = torch.distributions.categorical.Categorical(logits=output)
+ input[:, 3 * L :] = dist.sample()
+
+
######################################################################
models = []