raise ValueError(f"Unknown sandbox level {args.sandbox_level}")
task = tasks.SandBox(
- problem,
+ # problem,
# problems.ProblemAddition(zero_padded=False, inverted_result=False),
+ problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
####################
class ProblemLenId(Problem):
    """Toy "length identification" problem.

    Each row looks like  d0…d0 | d1…d1 > d1…d1 . ????  where the run after
    the '>' marker repeats the SECOND digit but with the FIRST run's length,
    so the model must carry length information across the separators.
    Token meanings: 0-9 digits, 10='|', 11='>', 12='.', 13='?' (padding).
    """

    def __init__(self, nb_sentences=100, len_max=5):
        # nb_sentences is accepted for interface compatibility with the other
        # Problem subclasses but is not used by this problem.
        self.len_max = len_max

    def generate_sequences(self, nb):
        """Return (sequences, ar_mask) for nb rows of fixed width.

        ar_mask is 1 exactly on the positions strictly after the '>' marker,
        i.e. the part the model is asked to generate autoregressively.
        """
        width = self.len_max * 3 + 3
        pos = torch.arange(width)[None, :]
        # Two random run lengths in [1, len_max] and two random digits per row
        # (same draw order as the original implementation, so results match
        # under a fixed RNG seed).
        run_len = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
        digit = torch.randint(10, (2, nb))[:, :, None]
        end_first = run_len[0]
        end_second = run_len[0] + 1 + run_len[1]
        end_answer = run_len[0] + 1 + run_len[1] + 1 + run_len[0]
        # Disjoint position masks select each segment; summing them assembles
        # the full token row.
        first_run = (pos < end_first) * digit[0]
        sep_bar = (pos == end_first) * 10
        second_run = ((pos > end_first) & (pos < end_second)) * digit[1]
        sep_gt = (pos == end_second) * 11
        answer_run = ((pos > end_second) & (pos < end_answer)) * digit[1]
        sep_dot = (pos == end_answer) * 12
        padding = (pos > end_answer) * 13
        sequences = (
            first_run + sep_bar + second_run + sep_gt + answer_run + sep_dot + padding
        )
        # Mask everything strictly after the first '>' token.
        marker = (sequences == 11).long()
        ar_mask = (marker.cumsum(1) - marker).clamp(max=1)
        return sequences, ar_mask

    def seq2str(self, seq):
        """Render a token sequence, mapping special tokens to '|', '>', '.', '?'."""
        alphabet = "0123456789|>.?"
        return "".join(alphabet[t.item()] for t in seq)
+
+
+####################
+
+
class ProblemLevel0(Problem):
def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
+ def seq2str(self, seq):
+ return "".join("0123456789|"[x.item()] for x in seq)
+
+
+####################
+
class ProblemLevel1(Problem):
def __init__(self, nb_operators=100, len_source=5, len_result=8):
return "".join("0123456789|>"[x.item()] for x in seq)
+####################
+
+
class ProblemLevel2(Problem):
def __init__(self, len_source=5, len_result=8):
self.len_source = len_source
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ # NOTE(review): fragment of a larger evaluation method (uses self/model/
+ # n_epoch defined outside this hunk). Dumps per-head attention maps for one
+ # randomly chosen test sample as PDFs, one file per head.
+ if save_attention_image is not None:
+ ns = torch.randint(self.test_input.size(0), (1,)).item()
+ input = self.test_input[ns : ns + 1].clone()
+
+ # Record attention during a single forward pass, with gradients disabled
+ # and the model temporarily switched to eval mode (restored afterwards).
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+ model.record_attention(True)
+ model(BracketedSequence(input))
+ model.train(t)
+ ram = model.retrieve_attention()
+ model.record_attention(False)
+
+ # Input tokens are the output tokens shifted right by one (autoregressive).
+ tokens_output = [c for c in self.problem.seq2str(input[0])]
+ tokens_input = ["n/a"] + tokens_output[:-1]
+ # ram presumably is a list of per-layer tensors shaped
+ # (batch, head, query, key) — TODO confirm against the model class.
+ for n_head in range(ram[0].size(1)):
+ filename = os.path.join(
+ result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
+ )
+ attention_matrices = [m[0, n_head] for m in ram]
+ save_attention_image(
+ filename,
+ tokens_input,
+ tokens_output,
+ attention_matrices,
+ k_top=10,
+ # min_total_attention=0.9,
+ token_gap=12,
+ layer_gap=50,
+ )
+ # NOTE(review): "(unknown)" below looks like a corrupted placeholder —
+ # it probably should read f"wrote {filename}"; confirm before merging.
+ logger(f"wrote (unknown)")
+
+
######################################################################