parser.add_argument("--proba_not_understands", type=float, default=0.5)
-parser.add_argument("--temperature_hot", type=float, default=2)
+parser.add_argument("--temperature_hot", type=float, default=1.25)
-parser.add_argument("--temperature_cold", type=float, default=0.75)
+parser.add_argument("--temperature_cold", type=float, default=1.25)
-parser.add_argument("--nb_rounds", type=int, default=1)
+parser.add_argument("--nb_rounds", type=int, default=2)
parser.add_argument("--c_quiz_validation_mode", type=str, default="predict")
######################################################################
+# Experimental: train a non-causal auto-encoder on w_quizzes (invoked
+# through the commented-out call below).
+
+
+def train_auto_encoder():
+ model = mygpt.MyGPT(
+ vocabulary_size=vocabulary_size,
+ dim_model=args.dim_model,
+ dim_keys=args.dim_keys,
+ dim_hidden=args.dim_hidden,
+ nb_heads=args.nb_heads,
+ nb_blocks=args.nb_blocks,
+ causal=False,
+ dropout=args.dropout,
+ auto_encoder_dim=64,
+ ).to(main_device)
+
+ test_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_test_samples)
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+
+ nb_train_samples, acc_train_loss = 0, 0.0
+
+ for n_epoch in range(args.nb_epochs):
+ train_w_quizzes = quiz_machine.problem.generate_w_quizzes(args.nb_train_samples)
+ for input in tqdm.tqdm(
+ train_w_quizzes.split(args.batch_size),
+ dynamic_ncols=True,
+ desc="training AE",
+ total=train_w_quizzes.size(0) // args.batch_size,
+ ):
+ model.train()
+            # keep only the last quarter of each quiz
+            l = input.size(1) // 4
+            input = input[:, -l:].to(main_device)
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.zero_grad()
+
+            # input is already on main_device; reconstruct it from the latent
+            z_shape = model.encode(mygpt.BracketedSequence(input))
+            output = model.decode(z_shape).x
+            loss = F.cross_entropy(output.transpose(1, 2), input)
+ acc_train_loss += loss.item() * input.size(0)
+
+ nb_train_samples += input.size(0)
+
+ loss.backward()
+
+ if nb_train_samples % args.batch_size == 0:
+ optimizer.step()
+
+ train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
+
+ log_string(f"train_perplexity {n_epoch} model ae {train_perplexity}")
+
+        filename = "auto_encoder.pth"
+ torch.save(
+ model.state_dict(),
+ os.path.join(args.result_dir, filename),
+ )
+ log_string(f"wrote {filename}")
+
+        with torch.no_grad():
+            model.eval()
+            l = test_w_quizzes.size(1) // 4
+            input = test_w_quizzes[:128, -l:]
+            z_shape = model.encode(mygpt.BracketedSequence(input.to(main_device)))
+            logits = model.decode(z_shape).x
+
+ # dist = torch.distributions.categorical.Categorical(logits=logits)
+ # q = dist.sample()
+
+            # greedy reconstruction, then interleave pairs of original and
+            # reconstructed quizzes for visual comparison
+            q = logits.argmax(dim=-1)
+            q = q.reshape(q.size(0) // 2, 2, -1)
+            input = input.reshape(input.size(0) // 2, 2, -1)
+            q = torch.cat([input.to("cpu"), q.to("cpu")], dim=1).reshape(q.size(0), -1)
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ f"culture_ae_{n_epoch:04d}.png",
+ q,
+ )
+
+ return model
+
+
+# ae = train_auto_encoder()
+
+# exit(0)
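+
+# A minimal usage sketch (not executed): embed a batch of quizzes and
+# reconstruct them greedily, assuming `ae = train_auto_encoder()` above has
+# been uncommented; quizzes are trimmed to their last quarter as in the
+# training loop.
+#
+# quizzes = quiz_machine.problem.generate_w_quizzes(128)
+# l = quizzes.size(1) // 4
+# z_shape = ae.encode(mygpt.BracketedSequence(quizzes[:, -l:].to(main_device)))
+# reconstruction = ae.decode(z_shape).x.argmax(dim=-1)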
+
+######################################################################
+
models = []
##############################
+
+
+class EncoderHead(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, bs):
+        # mean-pool the projected sequence into one latent vector per sample,
+        # and return the input shape so the decoder can expand back to it
+        z = self.fc(bs.x).mean(dim=1)
+        return z, bs.x.shape
+
+
+class DecoderBottom(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.fc = nn.Linear(dim_in, dim_out)
+
+    def forward(self, z_shape):
+        z, shape = z_shape
+        # project the latent back to the model dimension and broadcast it
+        # to every position of the original sequence shape
+        y = self.fc(z)[:, None, :].expand(shape)
+        return BracketedSequence(y)
+
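+# Illustrative shape walk-through, assuming a batch of N sequences of
+# length T and a model width of dim_model:
+#
+#   EncoderHead:   bs.x (N, T, dim_model) -> z (N, auto_encoder_dim),
+#                  plus the shape (N, T, dim_model) kept for the decoder
+#   DecoderBottom: z (N, auto_encoder_dim) -> (N, dim_model), broadcast
+#                  to (N, T, dim_model) so every position starts from the
+#                  same latent code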
+
+##############################
+
+
class MyGPT(nn.Module):
def __init__(
self,
nb_blocks,
causal=False,
dropout=0.0,
+        auto_encoder_dim=-1,  # <= 0 disables the auto-encoder heads
len_max=1e5,
):
super().__init__()
),
]
+        if auto_encoder_dim > 0:
+            # encoder: first half of the trunk followed by a mean-pooling
+            # head producing a fixed-size latent vector
+            self.encoder = nn.Sequential(
+                *(
+                    trunk_blocks[: nb_blocks // 2]
+                    + [EncoderHead(dim_model, auto_encoder_dim)]
+                )
+            )
+
+            # decoder: broadcast the latent back to a full-length sequence,
+            # restore positional information, and run the second half
+            self.decoder = nn.Sequential(
+                *(
+                    [
+                        DecoderBottom(auto_encoder_dim, dim_model),
+                        AddPositionalEncoding(len_max),
+                    ]
+                    + trunk_blocks[nb_blocks // 2 :]
+                )
+            )
+
self.trunk = nn.Sequential(*trunk_blocks)
self.readout = CacheWrapper(
m.weight.fill_(1.0)
def forward(self, bs):
- # print(f"GENERATE {bs.first} {bs.first+bs.nb}")
bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
return bs
+
+    def encode(self, bs):
+        # embed the tokens and run the first half of the trunk; returns the
+        # latent vector together with the pre-pooling activation shape
+        bs = self.embedding(bs)
+        z = self.encoder(bs)
+        return z
+
+    def decode(self, z_shape):
+        # expand the latent back into a sequence and read out token logits
+        bs = self.decoder(z_shape)
+        bs = self.readout(bs)
+        return bs
+
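+    # A minimal round-trip sketch (illustrative): for a model built with
+    # auto_encoder_dim > 0 and a token tensor x of shape (N, T),
+    #
+    #     z_shape = model.encode(BracketedSequence(x))
+    #     logits = model.decode(z_shape).x  # (N, T, vocabulary_size)
+    #
+    # the reconstruction can be scored with
+    # F.cross_entropy(logits.transpose(1, 2), x).
+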
def partial_forward(self, bs, start_layer=None, end_layer=None):
if start_layer is None:
# print(f"GENERATE {bs.first} {bs.first+bs.nb}")