nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
- full_input, full_mask_loss = quiz_machine.data_input(
+ full_input, _, full_mask_loss = quiz_machine.data_input(
args.nb_test_samples, model.test_c_quiz_bags, args.c_quiz_multiplier
)
src = zip(
log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")
- input, _ = quiz_machine.data_input(
+ input, _, _ = quiz_machine.data_input(
2000, model.test_c_quiz_bags, args.c_quiz_multiplier
)
nb_train_samples, acc_train_loss = 0, 0.0
- full_input, full_mask_loss = quiz_machine.data_input(
+ full_input, _, full_mask_loss = quiz_machine.data_input(
args.nb_train_samples,
model.train_c_quiz_bags + common_c_quiz_bags,
args.c_quiz_multiplier,
from mygpt import (
WithResidual,
CacheWrapper,
- AddPositionalEncoding,
+ VaswaniPositionalEncoding,
+ TrainablePositionalEncoding,
QKVAttention,
BracketedSequence,
)
self.embedding = nn.Sequential(
CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
- AddPositionalEncoding(len_max),
+ VaswaniPositionalEncoding(len_max),
)
def trunk(depth):
from mygpt import (
WithResidual,
CacheWrapper,
- AddPositionalEncoding,
+ VaswaniPositionalEncoding,
QKVAttention,
BracketedSequence,
)
nb_heads,
nb_blocks,
dropout=0.0,
- len_max=1e5,
+ len_max=1024,
):
super().__init__()
CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
)
- self.positional_encoding = AddPositionalEncoding(len_max)
+ self.positional_encoding = TrainablePositionalEncoding(dim_model, len_max)
trunk_blocks = []
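+ # Each entry: (structure, quad_generate, quad_noise, quad_loss), one flag per field (A, f_A, B, f_B)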
(("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (0, 0, 0, 1)),
]
- full_input, full_mask_loss = quiz_machine.data_input(
+ full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
args.nb_train_samples, data_structures=data_structures
)
model.optimizer.zero_grad()
targets = input
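+ # Blank the tokens flagged by mask_generate; the cross-entropy against the intact targets trains the model to reconstruct them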
- input = (mask_loss == 0).long() * input
+ input = (mask_generate == 0).long() * input
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
nb_test_samples, acc_test_loss = 0, 0.0
- full_input, full_mask_loss = quiz_machine.data_input(args.nb_test_samples)
+ full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+ args.nb_test_samples, data_structures=data_structures
+ )
- src = zip(
- full_input.split(args.batch_size), full_mask_loss.split(args.batch_size)
+ src = zip(
+ full_input.split(args.batch_size),
+ full_mask_generate.split(args.batch_size),
+ full_mask_loss.split(args.batch_size),
input = input.to(local_device)
+ mask_generate = mask_generate.to(local_device)
mask_loss = mask_loss.to(local_device)
targets = input
- input = (mask_loss == 0).long() * input
+ input = (mask_generate == 0).long() * input
output = model(mygpt.BracketedSequence(input)).x
loss = F.cross_entropy(output.transpose(1, 2), targets)
acc_test_loss += loss.item() * input.size(0)
log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
- input, mask_loss = quiz_machine.data_input(128)
+ input, mask_generate, mask_loss = quiz_machine.data_input(
+ 128, data_structures=data_structures
+ )
input = input.to(local_device)
+ mask_generate = mask_generate.to(local_device)
mask_loss = mask_loss.to(local_device)
targets = input
- input = (mask_loss == 0).long() * input
+ input = (mask_generate == 0).long() * input
logits = model(mygpt.BracketedSequence(input)).x
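+ # Draw a token at every position from the model's per-position categorical distribution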
dist = torch.distributions.categorical.Categorical(logits=logits)
result = dist.sample()
##############################
-class AddPositionalEncoding(nn.Module):
+class VaswaniPositionalEncoding(nn.Module):
def __init__(self, len_max):
super().__init__()
self.len_max = len_max
##############################
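+# Learned absolute positional encoding: one trainable vector per position,
+# added to the token embeddings (contrast with the fixed sinusoidal table
+# of VaswaniPositionalEncoding)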
+class TrainablePositionalEncoding(nn.Module):
+ def __init__(self, dim, len_max):
+ super().__init__()
+ self.len_max = len_max
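+ # Random init with per-component std 1/sqrt(dim), so each position vector starts at roughly unit norm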
+ self.pe = nn.Parameter(torch.randn(1, len_max, dim) / math.sqrt(dim))
+
+ def forward(self, bs):
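+ # bs is a BracketedSequence: bs.slice() holds the columns [first, first+nb);
+ # cache the output so successive incremental calls stay consistent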
+ if bs.first == 0:
+ self.cache_y = bs.x.new(bs.x.size())
+
+ self.cache_y[:, bs.first : bs.first + bs.nb] = (
+ bs.slice() + self.pe[:, bs.first : bs.first + bs.nb]
+ )
+
+ return BracketedSequence(self.cache_y, bs.first, bs.nb)
+
+
+##############################
+
+
class EncoderHead(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
)
- self.positional_encoding = AddPositionalEncoding(len_max)
+ self.positional_encoding = VaswaniPositionalEncoding(len_max)
trunk_blocks = []
quizzes, structs=[s for s, _, _, _ in data_structures]
)
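+ # Build a generation mask alongside the loss mask from each structure's quad flags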
+ quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
- for struct, _, quad_noise, quad_loss in data_structures:
+ for struct, quad_generate, quad_noise, quad_loss in data_structures:
i = self.problem.indices_select(quizzes=quizzes, struct=struct)
if i.any():
if self.prompt_noise > 0.0:
quizzes[i] = self.problem.inject_noise(
quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
)
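+ # Expand the quad flags into per-token masks for the quizzes of this structure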
+ quiz_mask_generate[i] = self.make_quiz_mask(
+ quizzes=quizzes[i], struct=struct, quad=quad_generate
+ )
quiz_mask_loss[i] = self.make_quiz_mask(
quizzes=quizzes[i], struct=struct, quad=quad_loss
)
- print("quad_loss", quad_loss)
- print("quiz_mask_loss", quiz_mask_loss)
-
- return quizzes, quiz_mask_loss
+ return quizzes, quiz_mask_generate, quiz_mask_loss
######################################################################