"--task",
type=str,
default="twotargets",
- help="byheart, learnop, guessop, degradation, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
+ help="byheart, learnop, guessop, mixing, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
##############################
# Misc
-parser.add_argument("--degradation_hard", action="store_true", default=False)
+parser.add_argument("--mixing_hard", action="store_true", default=False)
######################################################################
"nb_train_samples": 50000,
"nb_test_samples": 10000,
},
- "degradation": {
+ "mixing": {
"model": "37M",
"batch_size": 25,
"nb_train_samples": 250000,
        device=device,
    )
-elif args.task == "degradation":
+elif args.task == "mixing":
    task = tasks.SandBox(
-        problem=problems.ProblemDegradation(hard=args.degradation_hard),
+        problem=problems.ProblemMixing(hard=args.mixing_hard),
        nb_train_samples=args.nb_train_samples,
        nb_test_samples=args.nb_test_samples,
        batch_size=args.batch_size,
####################
-
-
class ProblemDegradation(Problem):
    def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False):
        assert value_max // nb_state_tokens >= 2
        return "".join(self.id2char[x.item()] for x in seq)
+####################
+
+
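+# Mixing problem: a height x width grid, initially filled with tokens
+# 0..height*width-1, is scrambled one step at a time by rolling a single
+# row or column. A sample is the flattened sequence of grid states; with
+# hard=True the trajectory is reversed (mixed state first, start grid last).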
+class ProblemMixing(Problem):
+    def __init__(self, height=3, width=3, nb_time_steps=12, hard=False):
+        self.height = height
+        self.width = width
+        self.nb_time_steps = nb_time_steps
+        self.hard = hard
+
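+    # Canonical start state: the grid 0..height*width-1, replicated nb times,
+    # with a singleton dim so it can be indexed like the output of moves().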
+    def start(self, nb):
+        return (
+            torch.arange(self.height * self.width)
+            .reshape(1, 1, self.height, self.width)
+            .expand(nb, -1, -1, -1)
+        )
+
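+    # All states reachable from x in one move: each of the height rows rolled
+    # left or right by one, then each of the width columns rolled up or down
+    # by one, giving 2 * height + 2 * width candidate grids per sample.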
+    def moves(self, x):
+        y = (
+            x[:, None, :, :]
+            .expand(-1, self.height * 2 + self.width * 2, -1, -1)
+            .clone()
+        )
+        k = 0
+
+        for i in range(self.height):
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, i, :] = y[:, k, i, :].roll(dims=-1, shifts=1)
+            k += 1
+
+        for j in range(self.width):
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=-1)
+            k += 1
+            y[:, k, :, j] = y[:, k, :, j].roll(dims=-1, shifts=1)
+            k += 1
+
+        return y
+
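+    # Generates nb trajectories of nb_time_steps grids: start from the
+    # canonical grid, apply one uniformly sampled move per step, flatten each
+    # grid and concatenate the steps into one token sequence. The returned
+    # mask is all ones, i.e. the whole sequence is to be generated.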
+    def generate_sequences(self, nb):
+        y = self.start(nb)
+        x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+
+        seq = [x.flatten(1)]
+
+        for t in range(self.nb_time_steps - 1):
+            y = self.moves(x)
+            x = y[torch.arange(nb), torch.randint(y.size(1), (nb,))]
+            seq.append(x.flatten(1))
+
+        if self.hard:
+            seq.reverse()
+
+        seq = torch.cat(seq, dim=1)
+        return seq, seq.new_full(seq.size(), 1, dtype=torch.int64)
+
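+    # A result counts as correct when its first grid (in forward order, so
+    # after undoing the hard reversal) matches the canonical start state and
+    # every later grid is exactly one legal move away from the previous one.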
+    def compute_nb_correct(self, input, ar_mask, result):
+        a = [
+            x.reshape(result.size(0), self.height, self.width)
+            for x in result.split(self.height * self.width, dim=1)
+        ]
+        if self.hard:
+            a.reverse()
+
+        x = a[0]
+
+        y = self.start(result.size(0)).to(x.device)
+        d = (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+        for t in range(self.nb_time_steps - 1):
+            x0, x = a[t], a[t + 1]
+            y = self.moves(x0)
+            d = d + (x[:, None] - y).abs().sum((-1, -2)).min(dim=-1).values
+
+        nb_total, nb_correct = result.size(0), (d == 0).long().sum().item()
+
+        return nb_total, nb_correct
+
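+    # Human-readable rendering: cells joined with "-", rows of a grid joined
+    # with spaces, and successive time steps separated by " | ".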
+    def seq2str(self, seq):
+        return " | ".join(
+            [
+                " ".join(
+                    ["-".join([f"{x:02d}" for x in s]) for s in r.split(self.width)]
+                )
+                for r in seq.split(self.height * self.width)
+            ]
+        )
+
+
+####################
+
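+# Quick sanity check: generate sequences, print a few, and report how many
+# of them pass compute_nb_correct.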
if __name__ == "__main__":
-    p = ProblemDegradation(hard=False)
+    p = ProblemMixing(hard=True)
    s, m = p.generate_sequences(10000)
-    for x in s[:100]:
+    for x in s[:5]:
        print(p.seq2str(x))
    print(p.compute_nb_correct(None, None, s))