######################################################################
+ def sample_rworld_states(self, N=1000):
+ nb_rec_max = 5
+ while True:
+ rn = torch.randint(nb_rec_max - 1, (N,)) + 2
+ ri = torch.randint(self.height, (N, nb_rec_max, 2)).sort(dim=2).values
+ rj = torch.randint(self.width, (N, nb_rec_max, 2)).sort(dim=2).values
+ rz = torch.randint(2, (N, nb_rec_max))
+ rc = torch.randint(self.nb_colors - 1, (N, nb_rec_max)) + 1
+ n = torch.arange(nb_rec_max)
+ nb_collisions = (
+ (
+ (ri[:, :, None, 0] <= ri[:, None, :, 1])
+ & (ri[:, :, None, 1] >= ri[:, None, :, 0])
+ & (rj[:, :, None, 0] <= rj[:, None, :, 1])
+ & (rj[:, :, None, 1] >= rj[:, None, :, 0])
+ & (rz[:, :, None] == rz[:, None, :])
+ & (n[None, :, None] < rn[:, None, None])
+ & (n[None, None, :] < n[None, :, None])
+ )
+ .long()
+ .flatten(1)
+ .sum(dim=1)
+ )
+
+ no_collision = nb_collisions == 0
+
+ if no_collision.any():
+ print(no_collision.long().sum() / N)
+ self.rn = rn[no_collision]
+ self.ri = ri[no_collision]
+ self.rj = rj[no_collision]
+ self.rz = rz[no_collision]
+ self.rc = rc[no_collision]
+ break
+
+ def task_rworld_change_color(self, A, f_A, B, f_B):
+ nb_rec = 3
+ c = torch.randperm(self.nb_colors - 1)[: nb_rec + 1] + 1
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[n]
+ f_X[i1:i2, j1:j2] = c[n if n > 0 else -1]
+
+ ######################################################################
+
# @torch.compile
def task_replace_color(self, A, f_A, B, f_B):
nb_rec = 3
# grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
grids = Grids()
+ grids.sample_rworld_states()
+ exit(0)
# nb = 5
# quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
parser.add_argument("--dirty_debug", action="store_true", default=False)
-parser.add_argument("--test_generator", action="store_true", default=False)
+parser.add_argument("--test", type=str, default=None)
######################################################################
######################################################################
-if args.test_generator:
+if args.test == "tsne":
+ model = models[0]
+
+ quizzes = []
+ labels = []
+ nb_samples_per_task = 1000
+
+ for n, t in enumerate(args.grids_world_tasks.split(",")):
+ quizzes.append(
+ quiz_machine.problem.generate_w_quizzes(nb_samples_per_task, [t])
+ )
+ labels.append(torch.full((quizzes[-1].size(0),), n))
+
+ quizzes = torch.cat(quizzes, dim=0)
+ labels = torch.cat(labels, dim=0)
+
+ with torch.autograd.no_grad():
+ model.eval().to(main_device)
+ record = []
+ for input, targets in zip(
+ quizzes.split(args.batch_size), labels.split(args.batch_size)
+ ):
+ input = input.to(main_device)
+ bs = mygpt.BracketedSequence(input)
+ bs = mygpt.BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
+ bs = model.embedding(bs)
+ bs = model.trunk[args.nb_blocks // 2](bs)
+ record.append((bs.x.to("cpu"), targets))
+
+ x = torch.cat([x for x, y in record], dim=0).flatten(1)
+ y = torch.cat([y for x, y in record], dim=0)
+
+ print(f"{x.size()=} {y.size()=}")
+ # torch.save((x,y), "/tmp/embed.pth")
+ # exit(0)
+
+ from sklearn.manifold import TSNE
+
+ x_np = x.numpy()
+ z_np = TSNE(n_components=2, perplexity=50).fit_transform(x_np)
+ z = torch.from_numpy(z_np)
+
+ print(f"{z.size()=}")
+
+ with open("/tmp/result.dat", "w") as f:
+ for k in range(z.size(0)):
+ f.write(f"{y[k]} {z[k,0]} {z[k,1]}\n")
+
+ exit(0)
+
+######################################################################
+
+if args.test == "generator":
token_prolog_0 = vocabulary_size + 0
token_prolog_1 = vocabulary_size + 1
token_prolog_2 = vocabulary_size + 2