def generate_sequences(self, nb):
+ # Stub on the base class: does nothing and returns None. The subclasses in
+ # this file override it and return a (sequences, ar_mask) pair for nb samples.
pass
- def log_performance(self, sequences, logger):
- pass
+ def seq2str(self, seq):
+ # Default rendering of a token sequence: ignores seq and returns a fixed
+ # placeholder string; subclasses override this with a real decoder.
+ return "[NOT IMPLEMENTED]"
+
+
+####################
-class ProblemByheart(Problem):
- def __init__(self):
- nb_seq, len_prompt, len_result = 100, 5, 5
+class ProblemLevel0(Problem):
+ def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
+ # Builds a fixed table of nb_sentences random digit rows of width
+ # len_prompt + 1 + len_result, with separator token 10 at column len_prompt.
+ # The removed ProblemByheart.__init__ bound nb_seq locally; rebind it from
+ # the new nb_sentences parameter so the unchanged lines below still resolve
+ # instead of raising NameError.
+ nb_seq = nb_sentences
self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
self.seq[:, len_prompt] = 10
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
- # problems = [ProblemByheart()]
- # nb_common_codes = 100
- # def generate_sequences(nb_samples):
- # problem_indexes = torch.randint(len(problems), (nb_samples,))
- # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
- # print(f"{nb_samples_per_problem}")
- # all_seq = []
- # for nb, p in zip(nb_samples_per_problem, problems):
- # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
- # return all_seq
+class ProblemLevel1(Problem):
+    """Synthetic task: apply a randomly chosen operator to a digit prompt.
+
+    Every generated row has the layout
+    ``<operator id digits> 10 <prompt digits> 11 <result digits>`` where
+
+    * the operator id is written with ``len_nb_operator`` decimal digits,
+      least-significant digit first,
+    * the prompt is ``len_prompt`` random digits in 0..9,
+    * each result digit is the prompt digit selected by the matching one-hot
+      row of the chosen operator.
+
+    Args:
+        nb_operators: number of distinct random operators to create.
+        len_prompt: number of digits in the prompt.
+        len_result: number of digits in the result.
+    """
+
+    # Printable symbol for every token id emitted by generate_sequences:
+    # ids 0-9 are the digits, 10 and 11 are the two separator tokens.
+    # seq2str indexes this table; without it seq2str raised AttributeError.
+    id2char = "0123456789|>"
+
+    def __init__(self, nb_operators=100, len_prompt=5, len_result=8):
+        self.len_prompt = len_prompt
+        self.len_result = len_result
+        # Number of decimal digits reserved for the operator id.
+        self.len_nb_operator = int(math.log(nb_operators) / math.log(10)) + 1
+        # One-hot tensor of shape (nb_operators, len_result, len_prompt): for
+        # each operator and each result position, exactly one prompt position
+        # is selected (the argmax of uniform noise, i.e. a random position).
+        self.operators = F.one_hot(
+            torch.rand(nb_operators, len_result, len_prompt).argmax(-1),
+            num_classes=len_prompt,
+        )
+
+    def generate_sequences(self, nb):
+        """Return ``(sequences, ar_mask)`` for ``nb`` random samples.
+
+        ``sequences`` is an int64 tensor with one row per sample; ``ar_mask``
+        has the same shape and is 1 exactly at the positions strictly after
+        the ``11`` separator (the result digits), 0 elsewhere.
+        """
+        a = self.len_nb_operator
+        b = a + 1 + self.len_prompt
+        sequences = torch.empty(nb, b + 1 + self.len_result, dtype=torch.int64)
+        operator_ids = torch.randint(self.operators.size(0), (nb,))
+        # Decimal digits of the operator id, units digit first. Integer floor
+        # division keeps the computation exact instead of going through
+        # float32 and relying on the truncating cast back to int64.
+        powers_of_ten = 10 ** torch.arange(a)
+        sequences[:, :a] = (
+            torch.div(operator_ids[:, None], powers_of_ten, rounding_mode="floor")
+            % 10
+        )
+        sequences[:, a] = 10
+        sequences[:, a + 1 : b] = torch.randint(10, (nb, b - a - 1))
+        sequences[:, b] = 11
+
+        o = self.operators[operator_ids]
+        p = sequences[:, a + 1 : b]
+        # (nb, len_result, len_prompt) @ (nb, len_prompt, 1): each one-hot row
+        # picks a single prompt digit.
+        sequences[:, b + 1 :] = o.bmm(p[:, :, None]).squeeze(-1)
+        ar_mask = (sequences == 11).long()
+        # cumsum - self marks everything strictly after the separator.
+        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        """Render a 1-D tensor of token ids as a string using ``id2char``."""
+        return "".join(self.id2char[x.item()] for x in seq)
- # for strain, stest in zip(train_seq, test_seq):
- # s = torch.cat((strain, stest), 0)
+
+####################
+
+
+class ProblemAddition(Problem):
+    """Decimal addition task rendered as character strings ``a+b=c$``.
+
+    Args:
+        nb_digits: both operands are drawn uniformly below ``10**nb_digits``.
+        zero_padded: left-pad operands to ``nb_digits`` characters and the
+            sum to ``nb_digits + 1`` characters with ``"0"``.
+        inverted_result: write the sum reversed (after any padding).
+    """
+
+    def __init__(self, nb_digits=10, zero_padded=False, inverted_result=False):
+        self.nb_digits = nb_digits
+        self.zero_padded = zero_padded
+        self.inverted_result = inverted_result
+        # Token vocabulary: the ten digits, the two operators and the "$"
+        # terminator/padding symbol, numbered in this order.
+        alphabet = "0123456789+=$"
+        self.char2id = {symbol: index for index, symbol in enumerate(alphabet)}
+        self.id2char = {index: symbol for symbol, index in self.char2id.items()}
+
+    def tensorize(self, strings):
+        """Encode strings as one tensor of token ids, right-padded with ``$``."""
+        width = max(len(s) for s in strings)
+        rows = [[self.char2id[ch] for ch in s.ljust(width, "$")] for s in strings]
+        return torch.tensor(rows)
+
+    def generate_sequences(self, nb):
+        """Return ``(sequences, ar_mask)`` for ``nb`` random additions.
+
+        ``ar_mask`` is 1 at every position strictly after the ``=`` token.
+        """
+        upper = 10**self.nb_digits
+        lines = []
+        for _ in range(nb):
+            x, y = torch.randint(upper, (2,))
+            lhs, rhs, total = str(x.item()), str(y.item()), str((x + y).item())
+            if self.zero_padded:
+                lhs = lhs.zfill(self.nb_digits)
+                rhs = rhs.zfill(self.nb_digits)
+                total = total.zfill(self.nb_digits + 1)
+            if self.inverted_result:
+                total = total[::-1]
+            lines.append(f"{lhs}+{rhs}={total}$")
+
+        sequences = self.tensorize(lines)
+        equals_id = self.char2id["="]
+        after_equals = (sequences == equals_id).long()
+        # cumsum - self marks everything strictly after the "=" token.
+        ar_mask = (after_equals.cumsum(1) - after_equals).clamp(max=1)
+        return sequences, ar_mask
+
+    def seq2str(self, seq):
+        """Decode a 1-D tensor of token ids back into its character string."""
+        return "".join(self.id2char[token.item()] for token in seq)
+
+
+# class ProblemUnion(Problem):
+# problems = [ProblemByheart()]
+# nb_common_codes = 100
+
+# def generate_sequences(nb_samples):
+# problem_indexes = torch.randint(len(problems), (nb_samples,))
+# nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
+# print(f"{nb_samples_per_problem}")
+# all_seq = []
+# for nb, p in zip(nb_samples_per_problem, problems):
+# all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
+# return all_seq
+
+# for strain, stest in zip(train_seq, test_seq):
+# s = torch.cat((strain, stest), 0)
+
+####################
class SandBox(Task):
self.batch_size = batch_size
self.device = device
+ self.problem = problem
- self.train_input, self.train_ar_mask = problem.generate_sequences(
+ self.train_input, self.train_ar_mask = self.problem.generate_sequences(
nb_train_samples
)
- self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
+ self.test_input, self.test_ar_mask = self.problem.generate_sequences(
+ nb_test_samples
+ )
+
+ self.train_input, self.train_ar_mask = self.train_input.to(
+ device
+ ), self.train_ar_mask.to(device)
+ self.test_input, self.test_ar_mask = self.test_input.to(
+ device
+ ), self.test_ar_mask.to(device)
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
return self.nb_codes
def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
):
- def compute_accuracy(input, ar_mask):
+ def compute_accuracy(input, ar_mask, logger=None):
+ input, ar_mask = input[:nmax], ar_mask[:nmax]
result = input.clone() * (1 - ar_mask)
+
masked_inplace_autoregression(
model,
self.batch_size,
device=self.device,
)
+ if logger is not None:
+ for sp, st in zip(result[:10], input[:10]):
+ logger(
+ f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}"
+ )
+ logger(
+ f" {n_epoch} ground truth {self.problem.seq2str(st)}"
+ )
+
nb_total = ar_mask.sum().item()
nb_correct = ((result == input).long() * ar_mask).sum().item()
)
test_nb_total, test_nb_correct = compute_accuracy(
- self.test_input, self.test_ar_mask
+ self.test_input, self.test_ar_mask, logger
)
logger(
device_storage=device_storage,
)
- print(f"{train_action_seq.size()=}")
-
train_frame_seq = self.frame2seq(train_frames).to(device_storage)
test_frame_seq = self.frame2seq(test_frames).to(device_storage)
self.nb_codes = nb_frame_codes + nb_action_codes
train_frame_seq = train_frame_seq.reshape(train_frame_seq.size(0) // 2, 2, -1)
- print(f"{train_action_seq.device=} {nb_frame_codes.device=}")
+
train_action_seq += nb_frame_codes
self.train_input = torch.cat(
(train_frame_seq[:, 0, :], train_action_seq, train_frame_seq[:, 1, :]), 1
(seq_start[:, None, :], seq_end[:, None, :], seq_predicted[:, None, :]), 1
)
result = result.reshape(-1, result.size(-1))
- print(f"{result.size()=}")
frames = self.seq2frame(result)
image_name = os.path.join(result_dir, f"world_result_{n_epoch:04d}.png")