parser.add_argument("--nb_test_samples", type=int, default=1000)
+parser.add_argument("--nb_train_alien_samples", type=int, default=0)
+
+parser.add_argument("--nb_test_alien_samples", type=int, default=0)
+
parser.add_argument("--nb_c_quizzes", type=int, default=2500)
parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
logger=log_string,
device=main_device,
)
-
# ------------------------------------------------------
######################################################################
subparam._grad.data = subparam._grad.data.to(device)
-######################################################################
-
-from mygpt import (
- WithResidual,
- CacheWrapper,
- VaswaniPositionalEncoding,
- TrainablePositionalEncoding,
- QKVAttention,
- BracketedSequence,
-)
-
-
-class Thinker(nn.Module):
- def __init__(
- self,
- vocabulary_size,
- dim_model,
- dim_keys,
- dim_hidden,
- nb_heads,
- nb_blocks,
- f_len,
- dropout=0.0,
- len_max=1e5,
- ):
- super().__init__()
-
- assert dim_model % nb_heads == 0
-
- self.embedding = nn.Sequential(
- CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- VaswaniPositionalEncoding(len_max),
- )
-
- def trunk(depth):
- trunk_blocks = []
-
- for b in range(nb_blocks):
- trunk_blocks += [
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- ),
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention_dropout=dropout,
- ),
- ),
- WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ),
- ]
-
- return nn.Sequential(*trunk_blocks)
-
- self.bottom_trunk = trunk(nb_blocks // 2)
-
- self.top_trunk = trunk(nb_blocks // 2)
-
- self.readout = CacheWrapper(
- nn.Linear(in_features=dim_model, out_features=vocabulary_size)
- )
-
- self.fun_embedding = nn.Parameter(torch.randn(1, f_len, dim_model))
-
- with torch.no_grad():
- for m in self.modules():
- if isinstance(m, nn.Embedding):
- m.weight.normal_(mean=0, std=2e-2)
- elif isinstance(m, nn.LayerNorm):
- m.bias.zero_()
- m.weight.fill_(1.0)
-
- def forward(self, bs):
- for m in self.modules():
- m.loss = 0
-
- L = bs.x.size(1) // 3
-
- bs = self.embedding(bs)
- A_fA = BracketedSequence(bs.x[:, : 2 * L])
- B = BracketedSequence(bs.x[:, -L:])
-
- bs = BracketedSequence(
- torch.cat([A_fA.x, self.fun_embedding.expand(bs.x.size(0), -1, -1)], dim=1)
- )
- bs = self.bottom_trunk(bs)
- bs = BracketedSequence(torch.cat([bs.x[:, -f_len:, :], B.x], dim=1))
- bs = self.top_trunk(bs)
- bs = BracketedSequence(bs.x[:, f_len:, :])
- bs = self.readout(bs)
-
- for m in self.modules():
- if m is not self:
- self.loss += m.loss
-
- return bs
-
-
######################################################################
from mygpt import (
WithResidual,
CacheWrapper,
- VaswaniPositionalEncoding,
+ CachedVaswaniPositionalEncoding,
QKVAttention,
BracketedSequence,
)
)
# self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
- self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
+ self.positional_encoding = CachedVaswaniPositionalEncoding(len_max=1e5)
trunk_blocks = []
return bs
-######################################################################
-
-# f = phi(A, f(A)) + phi(B, f(B))
-# \hat{f(A)} = psi(A, f)
-# \hat{A} = psi_inv(f(A), f)
-# \hat{f(B)} = psi(B, f)
-# \hat{B} = psi_inv(f(B), f)
-
-
-def attention_layer(dim_model, dim_keys, nb_heads, dropout):
- return WithResidual(
- CacheWrapper(
- nn.LayerNorm((dim_model,)),
- ),
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention_dropout=dropout,
- ),
- )
-
-
-class FunctionalAE(nn.Module):
- def __init__(
- self,
- vocabulary_size,
- dim_model,
- dim_keys,
- dim_hidden,
- nb_heads,
- nb_blocks,
- dropout=0.0,
- len_max=1024,
- ):
- super().__init__()
-
- assert dim_model % nb_heads == 0
-
- self.embedding = CacheWrapper(
- nn.Sequential(
- MultiEmbedding((vocabulary_size, 2), dim_model), nn.Dropout(dropout)
- ),
- )
-
- # self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
- self.positional_encoding = VaswaniPositionalEncoding(len_max=1e5)
-
- def trunk(nb, bottom=True):
- trunk_blocks = [VaswaniPositionalEncoding(len_max=1e5)]
-
- la = [
- QKVAttention(
- dim_in=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention_dropout=dropout,
- ),
- ]
-
- # if not bottom:
- # trunk_blocks += la
-
- for b in range(nb):
- trunk_blocks += [
- attention_block(dim_model, dim_keys, nb_heads, dropout),
- ffw_block(dim_model, dim_hidden, nb_heads, dropout),
- ]
-
- # if bottom:
- # trunk_blocks += la
-
- return nn.Sequential(*trunk_blocks)
-
- self.phi = trunk(nb_blocks // 2, bottom=True)
- nb_f_tokens = 200
- self.f_tokens = nn.Parameter(
- torch.randn(1, nb_f_tokens, dim_model) / math.sqrt(nb_f_tokens)
- )
- self.psi = trunk(nb_blocks // 2, bottom=False)
- self.psi_inv = trunk(nb_blocks // 2, bottom=False)
- self.internal_pe = VaswaniPositionalEncoding(len_max=1e5)
-
- self.readout = CacheWrapper(
- nn.Linear(in_features=dim_model, out_features=vocabulary_size)
- )
-
- with torch.no_grad():
- for m in self.modules():
- if isinstance(m, nn.Embedding):
- m.weight.normal_(mean=0, std=2e-2)
- elif isinstance(m, nn.LayerNorm):
- m.bias.zero_()
- m.weight.fill_(1.0)
-
- def forward(self, bs):
- def cat(*x):
- return BracketedSequence(torch.cat(x, dim=1))
-
- if torch.is_tensor(bs):
- return self.forward(BracketedSequence(bs)).x
- bs = self.embedding(bs)
- bs = self.positional_encoding(bs)
-
- x_A, x_f_A, x_B, x_f_B = bs.x.chunk(4, dim=1)
-
- K = self.f_tokens.size(1)
- N, L = x_A.size()[:2]
-
- ft = self.f_tokens.expand(N, -1, -1)
-
- theta_A = self.phi(cat(ft, x_A, x_f_A)).x[:, :K, :]
- theta_B = self.phi(cat(ft, x_B, x_f_B)).x[:, :K, :]
-
- # if self.hook_theta is not None:
- # self.hook_theta(theta_A, theta_B)
-
- hat_f_A = self.psi(cat(x_A, theta_B)).x[:, :L]
- hat_f_B = self.psi(cat(x_B, theta_A)).x[:, :L]
-
- hat_A = self.psi_inv(cat(x_f_A, theta_B)).x[:, :L]
- hat_B = self.psi_inv(cat(x_f_B, theta_A)).x[:, :L]
-
- bs = cat(hat_A, hat_f_A, hat_B, hat_f_B)
-
- bs = self.readout(bs)
- return bs
-
-
######################################################################
# quad_order, quad_generate, quad_noise, quad_loss
data_structures,
local_device,
c_quizzes=None,
+ alien_quiz_machine=None,
+ nb_aliens=None,
desc=None,
batch_size=args.batch_size,
):
f"{prefix}test_accuracy {n_epoch} model {model.id} nb_correct {nb_correct} / {nb_total} ({(nb_correct*100)/nb_total:.02f}%)"
)
- model.test_accuracy = nb_correct / nb_total
-
# Save some images
- for f, record in [("prediction", record_d), ("generation", record_nd)]:
- result, predicted_parts, correct_parts = bag_to_tensors(record)
+ if n_epoch < 50:
+ for f, record in [("prediction", record_d), ("generation", record_nd)]:
+ result, predicted_parts, correct_parts = bag_to_tensors(record)
- filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
+ filename = f"{prefix}culture_{f}_{n_epoch:04d}_{model.id:02d}.png"
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=result[:128],
- predicted_parts=predicted_parts[:128],
- correct_parts=correct_parts[:128],
- )
+ quiz_machine.problem.save_quizzes_as_image(
+ args.result_dir,
+ filename,
+ quizzes=result[:128],
+ predicted_parts=predicted_parts[:128],
+ correct_parts=correct_parts[:128],
+ )
- log_string(f"wrote {filename}")
+ log_string(f"wrote {filename}")
+
+ return nb_correct / nb_total
######################################################################
f"train_loss {n_epoch} model {model.id} {acc_train_loss/nb_train_samples}"
)
- run_ae_test(model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device)
+ model.test_accuracy = run_ae_test(
+ model, quiz_machine, n_epoch, c_quizzes=None, local_device=local_device
+ )
+
+ if args.nb_test_alien_samples > 0:
+ run_ae_test(
+ model,
+ alien_quiz_machine,
+ n_epoch,
+ c_quizzes=None,
+ local_device=local_device,
+ prefix="alien",
+ )
######################################################################
def generate_ae_c_quizzes(models, nb, local_device=main_device):
# To be thread-safe we must make copies
+
+ def copy_for_inference(model):
+ return copy.deepcopy(model).to(local_device).eval()
+
quad_order = ("A", "f_A", "B", "f_B")
template = quiz_machine.problem.create_empty_quizzes(
quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
)
- def copy_for_inference(model):
- return copy.deepcopy(model).to(local_device).eval()
-
wanted_nb = nb
nb_to_save = 256
nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)