"--task",
type=str,
default="twotargets",
- help="byheart, learnop, guessop, mixing, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+ help="byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
+ "memory": {
+ "model": "4M",
+ "batch_size": 100,
+ "nb_train_samples": 5000,
+ "nb_test_samples": 1000,
+ },
"mixing": {
"model": "37M",
"batch_size": 25,
"nb_heads": 2,
"nb_blocks": 2,
},
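+    # smaller configuration used by default for the memory task; presumably
+    # named, like "37M", after its approximate parameter count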
+ "4M": {
+ "dim_model": 256,
+ "dim_keys": 32,
+ "dim_hidden": 1024,
+ "nb_heads": 4,
+ "nb_blocks": 6,
+ },
"37M": {
"dim_model": 512,
"dim_keys": 64,
device=device,
)
+elif args.task == "memory":
+ task = tasks.SandBox(
+ problem=problems.ProblemMemory(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
elif args.task == "mixing":
task = tasks.SandBox(
problem=problems.ProblemMixing(
####################
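+# Single-sequence copy task: a pattern of noise tokens delimited by "[" and
+# "]" appears at a random position, and the model must reproduce it between
+# the result delimiters "<" and ">" near the end of the sequence.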
+class ProblemMemory(Problem):
+ def __init__(self, len_total=25):
+ self.len_total = len_total
+ self.max_len_pattern = 5
+ self.nb_noise_tokens = 10
+ self.start_pattern_token = 0
+ self.end_pattern_token = 1
+ self.start_result_token = 2
+ self.end_result_token = 3
+ self.token_string = "[]<>" + "".join(
+ [chr(ord("a") + k) for k in range(self.nb_noise_tokens)]
+ )
+
+ def generate_sequences(self, nb):
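+        # start from pure noise; noise token values sit just above the four
+        # special tokens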
+ sequences = (
+ torch.randint(self.nb_noise_tokens, (nb, self.len_total))
+ + self.end_result_token
+ + 1
+ )
+ len_patterns = torch.randint(self.max_len_pattern, (nb,)) + 1
+ pattern_positions = torch.randint(
+ self.len_total - (5 + 2 * self.max_len_pattern), (nb,)
+ )
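+        # fixed position of the result opening delimiter, leaving just
+        # enough room for the longest pattern and its closing delimiter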
+ k = self.len_total - (3 + self.max_len_pattern)
+ for i in range(nb):
+ l = len_patterns[i]
+ j = pattern_positions[i]
+ sequences[i, j] = self.start_pattern_token
+ sequences[i, j + l + 2] = self.end_pattern_token
+ sequences[i, k] = self.start_result_token
+ sequences[i, k + l + 2] = self.end_result_token
+ sequences[i, k + 1 : k + 2 + l] = sequences[i, j + 1 : j + 2 + l]
+
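+        # the autoregressive mask covers exactly the copied tokens, at
+        # positions k+1 to k+1+l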
+ j = torch.arange(self.len_total)[None, :]
+ ar_mask = (j > k).long() * (j <= k + 1 + len_patterns[:, None]).long()
+
+ return sequences, ar_mask
+
+ def seq2str(self, seq):
+ return "".join(self.token_string[x.item()] for x in seq)
+
+
class ProblemTwoTargets(Problem):
def __init__(self, len_total=10, len_targets=3):
assert len_targets >= 3
return y
def start_error(self, x):
- i = torch.arange(self.height, device=x.device).reshape(1, -1, 1).expand_as(x)
- j = torch.arange(self.width, device=x.device).reshape(1, 1, -1).expand_as(x)
-
- ri = (
- (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1, 1, 1)
- )
- rj = (
- (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1, 1, 1)
- )
+ if self.random_start:
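+            # recover the row and column filled with the start marker
+            # (value height * width); those cells are checked against the
+            # marker, all others against their canonical value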
+ i = (
+ torch.arange(self.height, device=x.device)
+ .reshape(1, -1, 1)
+ .expand_as(x)
+ )
+            j = (
+                torch.arange(self.width, device=x.device)
+                .reshape(1, 1, -1)
+                .expand_as(x)
+            )
+
+ ri = (
+ (x == self.height * self.width)
+ .long()
+ .sum(dim=-1)
+ .argmax(-1)
+ .view(-1, 1, 1)
+ )
+ rj = (
+ (x == self.height * self.width)
+ .long()
+ .sum(dim=-2)
+ .argmax(-1)
+ .view(-1, 1, 1)
+ )
- m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+ m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)
+ else:
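+            # no marker row/column: every cell is checked against its
+            # canonical value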
+ m = 1
x = x.flatten(1)
u = torch.arange(self.height * self.width, device=x.device).reshape(1, -1)
d = (x - (m * u + (1 - m) * self.height * self.width)).abs().sum(-1)
+
return d
def moves(self, x):
####################
if __name__ == "__main__":
- p = ProblemMixing()
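+    # exercise a small deterministic instance and print a few sequences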
+ p = ProblemMixing(height=3, width=3, random_start=False)
+
s, m = p.generate_sequences(10000)
for x in s[:5]:
print(p.seq2str(x))
(0, 1),
}
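+        # log the first hundred training sequences together with their AR
+        # masks rendered as strings of "0"/"1"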
+ if logger is not None:
+ for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
+ logger(f"train_sequences {self.problem.seq2str(s)}")
+ a = "".join(["01"[x.item()] for x in a])
+ logger(f" {a}")
+
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
with torch.autograd.no_grad():
t = model.training
model.eval()
- model.record_attention(True)
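+                    # attention recording and the attention-map PDF dump
+                    # below are switched off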
+ # model.record_attention(True)
model(BracketedSequence(input))
model.train(t)
- ram = model.retrieve_attention()
- model.record_attention(False)
-
- tokens_output = [c for c in self.problem.seq2str(input[0])]
- tokens_input = ["n/a"] + tokens_output[:-1]
- for n_head in range(ram[0].size(1)):
- filename = os.path.join(
- result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
- )
- attention_matrices = [m[0, n_head] for m in ram]
- save_attention_image(
- filename,
- tokens_input,
- tokens_output,
- attention_matrices,
- k_top=10,
- # min_total_attention=0.9,
- token_gap=12,
- layer_gap=50,
- )
- logger(f"wrote {filename}")
+ # ram = model.retrieve_attention()
+ # model.record_attention(False)
+
+ # tokens_output = [c for c in self.problem.seq2str(input[0])]
+ # tokens_input = ["n/a"] + tokens_output[:-1]
+ # for n_head in range(ram[0].size(1)):
+ # filename = os.path.join(
+ # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
+ # )
+ # attention_matrices = [m[0, n_head] for m in ram]
+ # save_attention_image(
+ # filename,
+ # tokens_input,
+ # tokens_output,
+ # attention_matrices,
+ # k_top=10,
+                    #     # min_total_attention=0.9,
+ # token_gap=12,
+ # layer_gap=50,
+ # )
+ # logger(f"wrote {filename}")
######################################################################