From: François Fleuret Date: Sun, 14 Jul 2024 22:20:00 +0000 (+0200) Subject: Merge branch 'dev' X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=5f5c6c079c2751a76887444c211c5c464e875ed0;hp=00f7b3d445af8bb57376faabbf74eadc145faf1f;p=culture.git Merge branch 'dev' --- diff --git a/grids.py b/grids.py index 47e5861..eea8c6c 100755 --- a/grids.py +++ b/grids.py @@ -17,6 +17,92 @@ from torch.nn import functional as F import problem +def grow_islands(nb, height, width, nb_seeds, nb_iterations): + w = torch.empty(5, 1, 3, 3) + + w[0, 0] = torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 1.0], + ] + ) + + w[1, 0] = torch.tensor( + [ + [-1.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + ) + + w[2, 0] = torch.tensor( + [ + [0.0, 1.0, -1.0], + [0.0, 0.0, 1.0], + [0.0, 0.0, 0.0], + ] + ) + + w[3, 0] = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0], + [0.0, 1.0, -1.0], + ] + ) + + w[4, 0] = torch.tensor( + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [-1.0, 1.0, 0.0], + ] + ) + + Z = torch.zeros(nb, height, width) + U = Z.flatten(1) + + for _ in range(nb_seeds): + M = F.conv2d(Z[:, None, :, :], w, padding=1) + M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) + M = ((M[:, 0] == 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] + M = M * torch.rand(M.size()) + M = M.flatten(1) + M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) + U += M * Q + + for _ in range(nb_iterations): + M = F.conv2d(Z[:, None, :, :], w, padding=1) + M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1) + M = ((M[:, 1] >= 0) & (Z == 0)).long() + Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None] + M = M * torch.rand(M.size()) + M = M.flatten(1) + M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1)) + U = Z.flatten(1) + U += M * Q + + M = Z.clone() + Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2)) + + while True: + W = Z.clone() + Z = F.max_pool2d(Z, 3, 1, 1) * M + if Z.equal(W): + break + + Z = Z.long() + U = Z.flatten(1) + V = F.one_hot(U).max(dim=1).values + W = V.cumsum(dim=1) - V + N = torch.arange(Z.size(0))[:, None, None].expand_as(Z) + Z = W[N, Z] + + return Z + + class Grids(problem.Problem): named_colors = [ ("white", [255, 255, 255]), @@ -32,11 +118,40 @@ class Grids(problem.Problem): ("gray", [128, 128, 128]), ] - def __init__(self, device=torch.device("cpu")): + def __init__( + self, + max_nb_cached_chunks=None, + chunk_size=None, + nb_threads=-1, + tasks=None, + ): self.colors = torch.tensor([c for _, c in self.named_colors]) self.height = 10 self.width = 10 - self.device = device + self.cache_rec_coo = {} + + all_tasks = [ + self.task_replace_color, + self.task_translate, + self.task_grow, + self.task_half_fill, + self.task_frame, + self.task_detect, + self.task_count, + self.task_trajectory, + self.task_bounce, + self.task_scale, + self.task_symbols, + self.task_isometry, + # self.task_islands, + ] + + if tasks is None: + self.all_tasks = all_tasks + else: + self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")] + + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) ###################################################################### @@ -110,10 +225,10 @@ class Grids(problem.Problem): c = c.long()[:, None] c = ( (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) - * torch.tensor([64, 64, 64], device=c.device) - + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) - + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) - + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) + * torch.tensor([64, 64, 64]) + + (c == 1).long() * torch.tensor([0, 255, 0]) + + (c == 0).long() * torch.tensor([255, 255, 255]) + + (c == -1).long() * torch.tensor([255, 0, 0]) ) y[...] = c[:, :, None, None] @@ -195,121 +310,133 @@ class Grids(problem.Problem): return len(self.colors) # @torch.compile - def rec_coo_(self, nb_rec, min_height=3, min_width=3): - # @torch.compile - def overlap(ia, ja, ib, jb): - return ( - ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1] - ) + def rec_coo( + self, + nb_rec, + min_height=3, + min_width=3, + surface_max=None, + prevent_overlap=False, + ): + if surface_max is None: + surface_max = self.height * self.width // 2 - if nb_rec == 3: - while True: - i = torch.randint(self.height + 1, (nb_rec, 2)).sort(dim=1).values - j = torch.randint(self.width + 1, (nb_rec, 2)).sort(dim=1).values - if ( - not ( - overlap(i[0], j[0], i[1], j[1]) - or overlap(i[0], j[0], i[2], j[2]) - or overlap(i[1], j[1], i[2], j[2]) - ) - and (i[:, 1] - i[:, 0]).min() >= min_height - and (j[:, 1] - j[:, 0]).min() >= min_width - ): - break - return ( - (i[0, 0], j[0, 0], i[0, 1], j[0, 1]), - (i[1, 0], j[1, 0], i[1, 1], j[1, 1]), - (i[2, 0], j[2, 0], i[2, 1], j[2, 1]), - ) + signature = (nb_rec, min_height, min_width, surface_max) - # That's quite a tensorial spaghetti mess to sample - # non-overlapping rectangles quickly, but made the generation of - # 100k samples go from 1h50 with a lame pure python code to 3min30s - # with this one. - # @torch.compile - def rec_coo(self, nb_rec, min_height=3, min_width=3): - nb_trials = 200 + try: + return self.cache_rec_coo[signature].pop() + except IndexError: + pass + except KeyError: + pass + N = 10000 while True: - v = ( - ( - torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device) - .sort(dim=-1) - .indices - < 2 - ) - .long() - .cumsum(dim=1) - == 1 - ).long() + while True: + i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values + j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values - h = ( - ( - torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device) - .sort(dim=-1) - .indices - < 2 + big_enough = ( + (i[:, 1] >= i[:, 0] + min_height) + & (j[:, 1] >= j[:, 0] + min_height) + & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max) ) - .long() - .cumsum(dim=1) - == 1 - ).long() - - i = torch.logical_and( - v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width - ) - v, h = v[i], h[i] - v = v[: v.size(0) - v.size(0) % nb_rec] - h = h[: h.size(0) - h.size(0) % nb_rec] - v = v.reshape(v.size(0) // nb_rec, nb_rec, -1) - h = h.reshape(h.size(0) // nb_rec, nb_rec, -1) + i, j = i[big_enough], j[big_enough] - r = v[:, :, :, None] * h[:, :, None, :] + n = i.size(0) - i.size(0) % nb_rec - valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1 + if n > 0: + break - v = v[valid] - h = h[valid] + i = i[:n].reshape(n // nb_rec, nb_rec, -1) + j = j[:n].reshape(n // nb_rec, nb_rec, -1) + + if prevent_overlap: + can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum( + dim=-1 + ) <= self.height * self.width + i, j = i[can_fit], j[can_fit] + if nb_rec == 2: + A_i1, A_i2, A_j1, A_j2 = ( + i[:, 0, 0], + i[:, 0, 1], + j[:, 0, 0], + j[:, 0, 1], + ) + B_i1, B_i2, B_j1, B_j2 = ( + i[:, 1, 0], + i[:, 1, 1], + j[:, 1, 0], + j[:, 1, 1], + ) + no_overlap = torch.logical_not( + (A_i1 >= B_i2) + & (A_i2 <= B_i1) + & (A_j1 >= B_j1) + & (A_j2 <= B_j1) + ) + i, j = i[no_overlap], j[no_overlap] + elif nb_rec == 3: + A_i1, A_i2, A_j1, A_j2 = ( + i[:, 0, 0], + i[:, 0, 1], + j[:, 0, 0], + j[:, 0, 1], + ) + B_i1, B_i2, B_j1, B_j2 = ( + i[:, 1, 0], + i[:, 1, 1], + j[:, 1, 0], + j[:, 1, 1], + ) + C_i1, C_i2, C_j1, C_j2 = ( + i[:, 2, 0], + i[:, 2, 1], + j[:, 2, 0], + j[:, 2, 1], + ) + no_overlap = ( + ( + (A_i1 >= B_i2) + | (A_i2 <= B_i1) + | (A_j1 >= B_j2) + | (A_j2 <= B_j1) + ) + & ( + (A_i1 >= C_i2) + | (A_i2 <= C_i1) + | (A_j1 >= C_j2) + | (A_j2 <= C_j1) + ) + & ( + (B_i1 >= C_i2) + | (B_i2 <= C_i1) + | (B_j1 >= C_j2) + | (B_j2 <= C_j1) + ) + ) + i, j = (i[no_overlap], j[no_overlap]) + else: + assert nb_rec == 1 - if v.size(0) > 0: + if i.size(0) > 1: break - av = torch.arange(v.size(2), device=self.device)[None, :] - ah = torch.arange(h.size(2), device=self.device)[None, :] - - return [ - (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1) - for i1, j1, i2, j2 in zip( - v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values, - h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values, - (v[0] * av).max(dim=-1).values, - (h[0] * ah).max(dim=-1).values, - ) + self.cache_rec_coo[signature] = [ + [ + ( + i[n, k, 0].item(), + j[n, k, 0].item(), + i[n, k, 1].item(), + j[n, k, 1].item(), + ) + for k in range(nb_rec) + ] + for n in range(i.size(0)) ] - # @torch.compile - def rec_coo_(self, x, n, min_height=3, min_width=3): - collision = x.new(x.size()) - while True: - collision[...] = 0 - result = [] - for _ in range(n): - while True: - i1, i2 = torch.randint(x.size(0), (2,)) - if i1 + min_height <= i2: - break - while True: - j1, j2 = torch.randint(x.size(1), (2,)) - if j1 + min_width <= j2: - break - collision[i1:i2, j1:j2] += 1 - if collision.max() > 1: - break - result.append((i1, j1, i2, j2)) - if collision.max() == 1: - break - return result + return self.cache_rec_coo[signature].pop() ###################################################################### @@ -318,7 +445,7 @@ class Grids(problem.Problem): nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] @@ -326,12 +453,16 @@ class Grids(problem.Problem): # @torch.compile def task_translate(self, A, f_A, B, f_B): - di, dj = torch.randint(3, (2,)) - 1 + while True: + di, dj = torch.randint(3, (2,)) - 1 + if di.abs() + dj.abs() > 0: + break + nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) i1, j1, i2, j2 = r[nb_rec - 1] if ( i1 + di >= 0 @@ -354,10 +485,10 @@ class Grids(problem.Problem): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1 - direction = torch.randint(2, (1,)) + direction = torch.randint(2, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: while True: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) i1, j1, i2, j2 = r[nb_rec - 1] if i1 + 3 < i2 and j1 + 3 < j2: break @@ -376,13 +507,13 @@ class Grids(problem.Problem): f_X[i1:i2, j1:j2] = c[n] # @torch.compile - def task_color_grow(self, A, f_A, B, f_B): + def task_half_fill(self, A, f_A, B, f_B): di, dj = torch.randint(2, (2,)) * 2 - 1 nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1 - direction = torch.randint(4, (1,)) + direction = torch.randint(4, (1,)).item() for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[2 * n] @@ -422,20 +553,24 @@ class Grids(problem.Problem): nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] - f_X[i1:i2, j1:j2] = c[n] if n == nb_rec - 1: - f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0 + f_X[i1:i2, j1] = c[n] + f_X[i1:i2, j2 - 1] = c[n] + f_X[i1, j1:j2] = c[n] + f_X[i2 - 1, j1:j2] = c[n] + else: + f_X[i1:i2, j1:j2] = c[n] # @torch.compile def task_detect(self, A, f_A, B, f_B): nb_rec = 3 c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1 for X, f_X in [(A, f_A), (B, f_B)]: - r = self.rec_coo(nb_rec) + r = self.rec_coo(nb_rec, prevent_overlap=True) for n in range(nb_rec): i1, j1, i2, j2 = r[n] X[i1:i2, j1:j2] = c[n] @@ -478,29 +613,51 @@ class Grids(problem.Problem): return no, nq, nq_diag - # @torch.compile def task_count(self, A, f_A, B, f_B): - N = (torch.randint(4, (1,)) + 2).item() - c = torch.randperm(len(self.colors) - 1)[:N] + 1 + while True: + error = False + + N = torch.randint(5, (1,)).item() + 1 + c = torch.zeros(N + 1) + c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1 + + for X, f_X in [(A, f_A), (B, f_B)]: + if not hasattr(self, "cache_count") or len(self.cache_count) == 0: + self.cache_count = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 8, + nb_iterations=self.height * self.width // 10, + ) + ) - for X, f_X in [(A, f_A), (B, f_B)]: - nb = torch.zeros(N, dtype=torch.int64) - q = torch.randint(N, (self.height * self.width,)) - k = torch.randperm(self.height * self.width) - for p in range(self.height * self.width): - i, j = k[p] % self.height, k[p] // self.height - no, nq, nq_diag = self.contact(X, i, j, c[q[p]]) - if no == 0 and nq_diag == 0: - if nq == 0: - if nb[q[p]] < self.width: - X[i, j] = c[q[p]] - nb[q[p]] += 1 - if nq == 1: - X[i, j] = c[q[p]] - - for n in range(N): - for j in range(nb[n]): - f_X[n, j] = c[n] + X[...] = self.cache_count.pop() + + k = (X.max() + 1 + (c.size(0) - 1)).item() + V = torch.arange(k) // (c.size(0) - 1) + V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % ( + c.size(0) - 1 + ) + 1 + V[0] = 0 + X[...] = c[V[X]] + + if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1: + f_X[...] = 0 + for e in range(1, N + 1): + for j in range((X == c[e]).sum() + 1): + if j < self.width: + f_X[e - 1, j] = c[e] + else: + error = True + break + else: + error = True + break + + if not error: + break # @torch.compile def task_trajectory(self, A, f_A, B, f_B): @@ -508,7 +665,10 @@ class Grids(problem.Problem): for X, f_X in [(A, f_A), (B, f_B)]: while True: di, dj = torch.randint(7, (2,)) - 3 - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) if ( abs(di) + abs(dj) > 0 and i + 2 * di >= 0 @@ -549,8 +709,9 @@ class Grids(problem.Problem): X[...] = 0 for _ in range((self.height * self.width) // 10): - i, j = torch.randint(self.height, (1,)), torch.randint( - self.width, (1,) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), ) X[i, j] = c[0] f_X[i, j] = c[0] @@ -560,7 +721,10 @@ class Grids(problem.Problem): if abs(di) + abs(dj) == 1: break - i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,)) + i, j = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) X[i, j] = c[1] f_X[i, j] = c[1] @@ -598,18 +762,21 @@ class Grids(problem.Problem): def task_scale(self, A, f_A, B, f_B): c = torch.randperm(len(self.colors) - 1)[:2] + 1 - i, j = torch.randint(self.height // 2, (1,)), torch.randint( - self.width // 2, (1,) + i, j = ( + torch.randint(self.height // 2, (1,)).item(), + torch.randint(self.width // 2, (1,)).item(), ) for X, f_X in [(A, f_A), (B, f_B)]: for _ in range(3): while True: - i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( - self.width // 2 + 1, (1,) + i1, j1 = ( + torch.randint(self.height // 2 + 1, (1,)).item(), + torch.randint(self.width // 2 + 1, (1,)).item(), ) - i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint( - self.width // 2 + 1, (1,) + i2, j2 = ( + torch.randint(self.height // 2 + 1, (1,)).item(), + torch.randint(self.width // 2 + 1, (1,)).item(), ) if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3: break @@ -639,7 +806,7 @@ class Grids(problem.Problem): ai, aj = i.float().mean(), j.float().mean() - q = torch.randint(3, (1,)) + 1 + q = torch.randint(3, (1,)).item() + 1 X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0] X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0] @@ -656,12 +823,12 @@ class Grids(problem.Problem): f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q] # @torch.compile - def task_ortho(self, A, f_A, B, f_B): + def task_isometry(self, A, f_A, B, f_B): nb_rec = 3 di, dj = torch.randint(3, (2,)) - 1 o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]]) m = torch.eye(2) - for _ in range(torch.randint(4, (1,))): + for _ in range(torch.randint(4, (1,)).item()): m = m @ o if torch.rand(1) < 0.5: m[0, :] = -m[0, :] @@ -710,9 +877,88 @@ class Grids(problem.Problem): ): break + def compute_distance(self, walls, goal_i, goal_j): + max_length = walls.numel() + dist = torch.full_like(walls, max_length) + + dist[goal_i, goal_j] = 0 + pred_dist = torch.empty_like(dist) + + while True: + pred_dist.copy_(dist) + dist[1:-1, 1:-1] = ( + torch.cat( + ( + dist[None, 1:-1, 1:-1], + dist[None, 1:-1, 0:-2], + dist[None, 2:, 1:-1], + dist[None, 1:-1, 2:], + dist[None, 0:-2, 1:-1], + ), + 0, + ).min(dim=0)[0] + + 1 + ) + + dist = walls * max_length + (1 - walls) * dist + + if dist.equal(pred_dist): + return dist * (1 - walls) + # @torch.compile - def task_islands(self, A, f_A, B, f_B): - pass + def task_distance(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:3] + 1 + dist0 = torch.empty(self.height + 2, self.width + 2) + dist1 = torch.empty(self.height + 2, self.width + 2) + for X, f_X in [(A, f_A), (B, f_B)]: + nb_rec = torch.randint(3, (1,)).item() + 1 + while True: + r = self.rec_coo(nb_rec, prevent_overlap=True) + X[...] = 0 + f_X[...] = 0 + for n in range(nb_rec): + i1, j1, i2, j2 = r[n] + X[i1:i2, j1:j2] = c[0] + f_X[i1:i2, j1:j2] = c[0] + while True: + i0, j0 = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i0, j0] == 0: + break + while True: + i1, j1 = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + if X[i1, j1] == 0: + break + dist1[...] = 1 + dist1[1:-1, 1:-1] = (X != 0).long() + dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1) + if ( + dist1[i0 + 1, j0 + 1] >= 1 + and dist1[i0 + 1, j0 + 1] < self.height * 4 + ): + break + + dist0[...] = 1 + dist0[1:-1, 1:-1] = (X != 0).long() + dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1) + + dist0 = dist0[1:-1, 1:-1] + dist1 = dist1[1:-1, 1:-1] + + D = dist1[i0, j0] + for d in range(1, D): + M = (dist0 == d) & (dist1 == D - d) + f_X[...] = (1 - M) * f_X + M * c[1] + + X[i0, j0] = c[2] + f_X[i0, j0] = c[2] + X[i1, j1] = c[2] + f_X[i1, j1] = c[2] # for X, f_X in [(A, f_A), (B, f_B)]: # n = torch.arange(self.height * self.width).reshape(self.height, self.width) @@ -722,24 +968,104 @@ class Grids(problem.Problem): # i,j=q%self.height,q//self.height # if - ###################################################################### + # @torch.compile + def task_puzzle(self, A, f_A, B, f_B): + S = 4 + i0, j0 = (self.height - S) // 2, (self.width - S) // 2 + c = torch.randperm(len(self.colors) - 1)[:4] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + while True: + f_X[...] = 0 + h = list(torch.randperm(c.size(0))) + n = torch.zeros(c.max() + 1) + for _ in range(2): + k = torch.randperm(S * S) + for q in k: + i, j = q % S + i0, q // S + j0 + if f_X[i, j] == 0: + r, s, t, u = ( + f_X[i - 1, j], + f_X[i, j - 1], + f_X[i + 1, j], + f_X[i, j + 1], + ) + r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)] + if r > 0 and n[r] < 6: + n[r] += 1 + f_X[i, j] = r + elif s > 0 and n[s] < 6: + n[s] += 1 + f_X[i, j] = s + elif t > 0 and n[t] < 6: + n[t] += 1 + f_X[i, j] = t + elif u > 0 and n[u] < 6: + n[u] += 1 + f_X[i, j] = u + else: + if len(h) > 0: + d = c[h.pop()] + n[d] += 1 + f_X[i, j] = d + + if n.sum() == S * S: + break - def all_tasks(self): - return [ - self.task_replace_color, - self.task_translate, - self.task_grow, - self.task_color_grow, - self.task_frame, - self.task_detect, - self.task_count, - self.task_trajectory, - self.task_bounce, - self.task_scale, - self.task_symbols, - self.task_ortho, - # self.task_islands, - ] + k = 0 + for d in range(4): + while True: + ii, jj = ( + torch.randint(self.height, (1,)).item(), + torch.randint(self.width, (1,)).item(), + ) + e = 0 + for i in range(S): + for j in range(S): + if ( + ii + i >= self.height + or jj + j >= self.width + or ( + f_X[i + i0, j + j0] == c[d] + and X[ii + i, jj + j] > 0 + ) + ): + e = 1 + if e == 0: + break + for i in range(S): + for j in range(S): + if f_X[i + i0, j + j0] == c[d]: + X[ii + i, jj + j] = c[d] + + def task_islands(self, A, f_A, B, f_B): + c = torch.randperm(len(self.colors) - 1)[:2] + 1 + for X, f_X in [(A, f_A), (B, f_B)]: + if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0: + self.cache_islands = list( + grow_islands( + 1000, + self.height, + self.width, + nb_seeds=self.height * self.width // 20, + nb_iterations=self.height * self.width // 2, + ) + ) + + A = self.cache_islands.pop() + + while True: + i, j = ( + torch.randint(self.height // 2, (1,)).item(), + torch.randint(self.width // 2, (1,)).item(), + ) + if A[i, j] > 0: + break + + X[...] = (A > 0) * c[0] + X[i, j] = c[1] + f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0] + + ###################################################################### def trivial_prompts_and_answers(self, prompts, answers): S = self.height * self.width @@ -747,11 +1073,9 @@ class Grids(problem.Problem): f_Bs = answers return (Bs == f_Bs).long().min(dim=-1).values > 0 - def generate_prompts_and_answers( - self, nb, tasks=None, progress_bar=False, device="cpu" - ): + def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False): if tasks is None: - tasks = self.all_tasks() + tasks = self.all_tasks S = self.height * self.width prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64) @@ -772,12 +1096,12 @@ class Grids(problem.Problem): f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width) B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width) f_B = answer.view(self.height, self.width) - task = tasks[torch.randint(len(tasks), (1,))] + task = tasks[torch.randint(len(tasks), (1,)).item()] task(A, f_A, B, f_B) return prompts.flatten(1), answers.flatten(1) - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -797,12 +1121,22 @@ class Grids(problem.Problem): nrow, ) + def save_some_examples(self, result_dir): + nb, nrow = 72, 4 + for t in self.all_tasks: + print(t.__name__) + prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t]) + self.save_quiz_illustrations( + result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow + ) + ###################################################################### if __name__ == "__main__": import time + # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4) grids = Grids() # nb = 1000 @@ -816,22 +1150,26 @@ if __name__ == "__main__": # print(f"{prompts.size(0)/delay:02f} seq/s") # exit(0) - if True: - nb = 72 + # if True: + nb, nrow = 72, 4 + # nb, nrow = 8, 2 - for t in grids.all_tasks(): - # for t in [grids.task_ortho]: - print(t.__name__) - prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) - grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4) + # for t in grids.all_tasks: + for t in [grids.task_distance]: + print(t.__name__) + prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) + grids.save_quiz_illustrations( + "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow + ) - exit(0) + # exit(0) - nb = 500 + nb = 1000 - for t in grids.all_tasks(): + # for t in grids.all_tasks: + for t in [grids.task_distance]: start_time = time.perf_counter() - prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t]) + prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t]) delay = time.perf_counter() - start_time print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s") @@ -841,7 +1179,7 @@ if __name__ == "__main__": predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1) predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1) - grids.save_quizzes( + grids.save_quiz_illustrations( "/tmp", "test", prompts[:nb], diff --git a/main.py b/main.py index 9c3d7f1..6b00bbf 100755 --- a/main.py +++ b/main.py @@ -15,22 +15,14 @@ import ffutils import mygpt import sky, grids, quiz_machine -from problem import MultiThreadProblem -# world quizzes vs. culture quizzes +import threading -###################################################################### - -if torch.cuda.is_available(): - device = torch.device("cuda") - torch.backends.cuda.matmul.allow_tf32 = True -else: - device = torch.device("cpu") +import torch.multiprocessing as mp ###################################################################### parser = argparse.ArgumentParser( - description="An implementation of GPT with cache.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) @@ -40,7 +32,9 @@ parser.add_argument("--result_dir", type=str, default=None) parser.add_argument("--seed", type=int, default=0) -parser.add_argument("--max_percents_of_test_in_train", type=int, default=1) +parser.add_argument("--resume", action="store_true", default=False) + +parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1) ######################################## @@ -54,6 +48,10 @@ parser.add_argument("--nb_train_samples", type=int, default=None) parser.add_argument("--nb_test_samples", type=int, default=None) +parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None) + +parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None) + parser.add_argument("--learning_rate", type=float, default=5e-4) ######################################## @@ -78,23 +76,34 @@ parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa parser.add_argument("--problem", type=str, default="grids") -parser.add_argument("--multi_thread_problem", action="store_true", default=False) +parser.add_argument("--nb_threads", type=int, default=1) + +parser.add_argument("--gpus", type=str, default="all") parser.add_argument("--nb_gpts", type=int, default=5) -parser.add_argument("--min_to_validate", type=int, default=None) +parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9) + +parser.add_argument("--proba_understands", type=float, default=0.9) -parser.add_argument("--max_to_validate", type=int, default=None) +parser.add_argument("--proba_not_understands", type=float, default=0.5) -parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975) +parser.add_argument("--generation_temperature", type=float, default=1.0) -parser.add_argument("--generation_temperature", type=float, default=2.0) +parser.add_argument("--dirty_debug", action="store_true", default=False) -parser.add_argument("--deterministic_validation", action="store_true", default=False) +###################################################################### -parser.add_argument("--bidirectional_validation", action="store_true", default=False) +grids_tasks = ", ".join( + [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks] +) -parser.add_argument("--dirty_debug", action="store_true", default=False) +parser.add_argument( + "--grids_tasks", + type=str, + default=None, + help="A comma-separated subset of: " + grids_tasks + ", or None for all.", +) ###################################################################### @@ -112,12 +121,6 @@ parser.add_argument("--sky_speed", type=int, default=3) args = parser.parse_args() -if args.min_to_validate is None: - args.min_to_validate = args.nb_gpts - 1 - -if args.max_to_validate is None: - args.max_to_validate = args.nb_gpts - 1 - if args.result_dir is None: args.result_dir = f"results_culture" @@ -125,7 +128,7 @@ if args.result_dir is None: default_args = { "model": "37M", - "batch_size": 100, + "batch_size": 25, "nb_train_samples": 100000, "nb_test_samples": 10000, } @@ -183,11 +186,15 @@ else: ###################################################################### -try: - os.mkdir(args.result_dir) -except FileExistsError: - print(f"result directory {args.result_dir} already exists") - exit(1) +if args.resume: + assert os.path.isdir(args.result_dir) + +else: + try: + os.mkdir(args.result_dir) + except FileExistsError: + print(f"result directory {args.result_dir} already exists") + exit(1) log_file = open(os.path.join(args.result_dir, args.log_filename), "a") @@ -213,6 +220,10 @@ def log_string(s): sys.stdout.flush() +now = time.strftime("%Y%m%d-%H%M%S", time.localtime()) + +os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py") + log_string(f"argv {' '.join(sys.argv)}") for n in vars(args): @@ -221,6 +232,19 @@ for n in vars(args): ###################################################################### +if args.gpus == "all": + gpus_idx = range(torch.cuda.device_count()) +else: + gpus_idx = [int(k) for k in args.gpus.split(",")] + +gpus = [torch.device(f"cuda:{n}") for n in gpus_idx] + +if torch.cuda.is_available(): + main_device = gpus[0] +else: + assert len(gpus) == 0 + main_device = torch.device("cpu") + if args.dirty_debug: args.nb_train_samples = 2500 args.nb_test_samples = 100 @@ -240,16 +264,23 @@ if args.problem == "sky": nb_birds=args.sky_nb_birds, nb_iterations=args.sky_nb_iterations, speed=args.sky_speed, + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, ) back_accuracy = False elif args.problem == "grids": - problem = grids.Grids(device=device) + problem = grids.Grids( + max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100, + chunk_size=100, + nb_threads=args.nb_threads, + tasks=args.grids_tasks, + ) back_accuracy = True else: raise ValueError -if args.multi_thread_problem: - problem = MultiThreadProblem(problem, args.nb_train_samples, chunk_size=1000) +problem.save_some_examples(args.result_dir) quiz_machine = quiz_machine.QuizMachine( problem=problem, @@ -259,12 +290,12 @@ quiz_machine = quiz_machine.QuizMachine( batch_size=args.physical_batch_size, result_dir=args.result_dir, logger=log_string, - device=device, + device=main_device, ) ###################################################################### -log_string(f"device {device}") +log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}") vocabulary_size = quiz_machine.vocabulary_size() @@ -272,64 +303,47 @@ log_string(f"vocabulary_size {vocabulary_size}") ###################################################################### -# Compute the entropy of the training tokens -token_count = 0 -for input in quiz_machine.batches(split="train", desc="train-entropy"): - token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum( - (0, 1) - ) -token_probas = token_count / token_count.sum() -entropy = -torch.xlogy(token_probas, token_probas).sum() -train_set_perplexity = math.exp(entropy) +def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device): + with torch.autograd.no_grad(): + model.eval().to(local_device) -###################################################################### -# A bit of paranoia never hurts + nb_test_samples, acc_test_loss = 0, 0.0 + nb_samples_accumulated = 0 -if args.max_percents_of_test_in_train >= 0: + for input in quiz_machine.batches(model, split="test"): + input = input.to(local_device) - def subsets_as_tuples(batches, cs): - s = set() - for batch in batches: - for x in batch: - s.add(tuple([v.item() for v in x])) - if len(s) == cs: - yield s - s = set() - yield s + bs = model(mygpt.BracketedSequence(input)) + output = bs.x - nb_test, nb_in_train = 0, 0 - for test_subset in subsets_as_tuples( - quiz_machine.batches(split="test", desc="test-check"), 25000 - ): - in_train = set() - for train_subset in subsets_as_tuples( - quiz_machine.batches(split="train", desc="train-check"), 25000 - ): - in_train.update(test_subset.intersection(train_subset)) - nb_in_train += len(in_train) - nb_test += len(test_subset) + loss = F.cross_entropy(output.transpose(1, 2), input) - log_string( - f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set" - ) + acc_test_loss += loss.item() * input.size(0) - assert ( - nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 - ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" + nb_test_samples += input.size(0) -############################## + test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) + log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}") -def one_epoch(model, quiz_machine): - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) + model.main_test_accuracy = quiz_machine.produce_results( + n_epoch=n_epoch, + model=model, + result_dir=args.result_dir, + deterministic_synthesis=deterministic_synthesis, + ) - model.train() + +def one_epoch(model, quiz_machine, local_device=main_device): + model.to(local_device).train() + + optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) nb_train_samples, acc_train_loss = 0, 0.0 - for input in quiz_machine.batches(split="train"): - input = input.to(device) + for input in quiz_machine.batches(model, split="train"): + input = input.to(local_device) if nb_train_samples % args.batch_size == 0: optimizer.zero_grad() @@ -347,148 +361,110 @@ def one_epoch(model, quiz_machine): train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples)) - log_string(f"train_perplexity {n_epoch} {train_perplexity}") + log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}") + run_tests(model, quiz_machine, deterministic_synthesis=False) -###################################################################### + model.to(main_device) -def run_tests(model, quiz_machine, deterministic_synthesis): - with torch.autograd.no_grad(): - model.eval() - - nb_test_samples, acc_test_loss = 0, 0.0 - nb_samples_accumulated = 0 - - for input in quiz_machine.batches(split="test"): - input = input.to(device) - - bs = model(mygpt.BracketedSequence(input)) - output = bs.x +###################################################################### - loss = F.cross_entropy(output.transpose(1, 2), input) +# This is the key routine that decides what generated quizzes to keep - acc_test_loss += loss.item() * input.size(0) - nb_test_samples += input.size(0) +# token_logprobas are NxMxT where M is the number of models - test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples)) - log_string(f"test_perplexity {n_epoch} {test_perplexity}") +def compute_valid_quizzes_(token_logprobas): + warnings.warn("validation with uniform constraints", RuntimeWarning) + l = token_logprobas.min(dim=-1).values.sort(dim=-1).values + return (l[:, 0] < math.log(0.1)) & (l[:, 1] > math.log(0.5)) - model.main_test_accuracy = quiz_machine.produce_results( - n_epoch=n_epoch, - model=model, - result_dir=args.result_dir, - deterministic_synthesis=deterministic_synthesis, - ) +def compute_valid_quizzes(token_logprobas): + l = token_logprobas.sum(dim=-1).sort(dim=-1).values + return (l[:, 0] < math.log(args.proba_not_understands)) & ( + l[:, 1] > math.log(args.proba_understands) + ) -###################################################################### +def extract_valid_quizzes_and_logprobas(recorded): + validated_quizzes, validated_logprobas = [], [] + for quizzes, token_logprobas in recorded: + validated_indices = compute_valid_quizzes(token_logprobas) + validated_quizzes.append(quizzes[validated_indices]) + validated_logprobas.append(token_logprobas[validated_indices]) -def valid_c_quizzes(recorded, criteria): - result = [q[criteria(c)] for q, c in recorded] - return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([]) + if len(validated_quizzes) > 0: + return torch.cat(validated_quizzes, dim=0), torch.cat( + validated_logprobas, dim=0 + ) + else: + return None, None ###################################################################### -def create_c_quizzes( - models, - quiz_machine, - nb_for_train=1000, - nb_for_test=100, -): - quizzes_and_nb_correct_records = [] - +def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100): nb_to_create = nb_for_train + nb_for_test - # ------------------------------------------------------------ + recorded_quizzes_logprobas = [] - standard_validity = lambda nb_correct: torch.logical_and( - nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate - ) - - file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat") + nb_validated = 0 - with open(file_name, "w") as logp_file: - while ( - valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0) - < nb_to_create - ): - # Select a model at random to generate the new quizzes + while nb_validated < nb_to_create: + model_for_generation = models[torch.randint(len(models), (1,))] - model_for_generation = models[torch.randint(len(models), (1,))] - - c_quizzes = quiz_machine.generate_quizzes( - nb_to_create, - model_for_generation=model_for_generation, - temperature=args.generation_temperature, - ) - - c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] - - if c_quizzes.size(0) > 0: - nb_correct, seq_logproba = quiz_machine.compute_correctness( - c_quizzes, - models, - bidirectional_validation=args.bidirectional_validation, - deterministic_validation=args.deterministic_validation, - ) - - for n, l in zip(nb_correct, seq_logproba): - s = " ".join([str(x.item()) for x in l]) - logp_file.write(f"{n} {s}\n") + c_quizzes = quiz_machine.generate_quizzes( + nb_to_create, + model_for_generation=model_for_generation, + temperature=args.generation_temperature, + ) - if args.dirty_debug: - nb_correct = torch.randint( - len(models) + 1, nb_correct.size(), device=c_quizzes.device - ) + c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)] - quizzes_and_nb_correct_records.append((c_quizzes, nb_correct)) + if c_quizzes.size(0) > 0: + token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes) + recorded_quizzes_logprobas.append((c_quizzes, token_logproba)) - nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0) - nv = " ".join([str(x.item()) for x in nv]) + ( + validated_quizzes, + validated_logprobas, + ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas) - nb_validated = valid_c_quizzes( - quizzes_and_nb_correct_records, standard_validity - ).size(0) + if validated_quizzes is not None: + nb_validated = validated_quizzes.size(0) - log_string( - f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}" - ) + log_string( + f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}" + ) # store the new c_quizzes which have been validated - new_c_quizzes = valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity) - - quiz_machine.reverse_random_half_in_place(new_c_quizzes) - - quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True) - quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False) + quiz_machine.reverse_random_half_in_place(validated_quizzes) + quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True) + quiz_machine.store_c_quizzes( + validated_quizzes[nb_for_train:nb_to_create], for_train=False + ) - # save a bunch of images to investigate what quizzes with a - # certain nb of correct predictions look like + ###################################################################### + # save images with their logprobas - for n in range(len(models) + 1): - s = ( - "_validated" - if n >= args.min_to_validate and n <= args.max_to_validate - else "" - ) + vq = validated_quizzes[:72] + vl = validated_logprobas[:72] - q = valid_c_quizzes( - quizzes_and_nb_correct_records, criteria=lambda nb_correct: nb_correct == n - )[:72] + if vq.size(0) > 0: + prefix = f"culture_c_quiz_{n_epoch:04d}" + filename = os.path.join(args.result_dir, prefix + "_logp.pth") + torch.save(vl, filename) + # with open(file_name, "w") as logp_file: + # for l in vl: + # s = " ".join([str(x.item()) for x in l]) + # logp_file.write(s + "\n") - quiz_machine.reverse_random_half_in_place(q) - - if q.size(0) > 0: - quiz_machine.save_quizzes( - args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q - ) + quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq) ###################################################################### @@ -496,6 +472,7 @@ def create_c_quizzes( models = [] for k in range(args.nb_gpts): + log_string(f"creating model {k} and its w_quizzes") model = mygpt.MyGPT( vocabulary_size=vocabulary_size, dim_model=args.dim_model, @@ -505,24 +482,109 @@ for k in range(args.nb_gpts): nb_blocks=args.nb_blocks, causal=True, dropout=args.dropout, - ).to(device) + ).to(main_device) model.main_test_accuracy = 0.0 model.id = k + model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples) + quiz_machine.reverse_random_half_in_place(model.train_w_quizzes) + model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples) + quiz_machine.reverse_random_half_in_place(model.test_w_quizzes) + models.append(model) +###################################################################### + +if args.resume: + try: + for model in models: + filename = f"gpt_{model.id:03d}.pth" + + try: + d = torch.load(os.path.join(args.result_dir, filename)) + model.load_state_dict(d[0]) + model.main_test_accuracy = d[1] + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + + try: + filename = "c_quizzes.pth" + quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename)) + log_string(f"successfully loaded {filename}") + except FileNotFoundError: + log_string(f"cannot find {filename}") + pass + + except: + log_string(f"error when loading {filename}.") + exit(1) + +###################################################################### nb_parameters = sum(p.numel() for p in models[0].parameters()) log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)") ###################################################################### -nb_new_c_quizzes_for_train = args.nb_train_samples // 50 -nb_new_c_quizzes_for_test = args.nb_test_samples // 50 +# Compute the entropy of the training tokens + +token_count = 0 +for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"): + token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum( + (0, 1) + ) +token_probas = token_count / token_count.sum() +entropy = -torch.xlogy(token_probas, token_probas).sum() +train_set_perplexity = math.exp(entropy) + +###################################################################### +# A bit of paranoia never hurts + +if args.max_percents_of_test_in_train >= 0: + + def subsets_as_tuples(batches, cs): + s = set() + for batch in batches: + for x in batch: + s.add(tuple([v.item() for v in x])) + if len(s) == cs: + yield s + s = set() + yield s + + nb_test, nb_in_train = 0, 0 + for test_subset in subsets_as_tuples( + quiz_machine.batches(models[0], split="test", desc="test-check"), 25000 + ): + in_train = set() + for train_subset in subsets_as_tuples( + quiz_machine.batches(models[0], split="train", desc="train-check"), 25000 + ): + in_train.update(test_subset.intersection(train_subset)) + nb_in_train += len(in_train) + nb_test += len(test_subset) + + log_string( + f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set" + ) + + assert ( + nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100 + ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set" + +###################################################################### + +if args.nb_new_c_quizzes_for_train is None: + args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50 + +if args.nb_new_c_quizzes_for_test is None: + args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50 log_string( - f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}" + f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}" ) ###################################################################### @@ -530,8 +592,9 @@ log_string( if args.dirty_debug: args.accuracy_to_make_c_quizzes = 0.0 args.nb_gpts = 2 - nb_new_c_quizzes_for_train = 100 - nb_new_c_quizzes_for_test = 10 + args.nb_new_c_quizzes_for_train = 100 + args.nb_new_c_quizzes_for_test = 10 + ###################################################################### @@ -541,46 +604,63 @@ for n_epoch in range(args.nb_epochs): cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models]) log_string(f"current_test_accuracies {cta}") + ################################################## + # If all the models are good enough, generate new quizzes and + # re-compute the test errors + + if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: + create_c_quizzes( + models, + quiz_machine, + nb_for_train=args.nb_new_c_quizzes_for_train, + nb_for_test=args.nb_new_c_quizzes_for_test, + ) + + filename = "c_quizzes.pth" + quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename)) + log_string(f"wrote {filename}") + + # Force one epoch of training + for model in models: + model.main_test_accuracy = 0.0 + ################################################## # Select, improve, and eval the worst model - weakest_model = min(models, key=lambda m: float(m.main_test_accuracy)) + ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy)) - log_string( - f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}" - ) + weakest_models = ranked_models[: len(gpus)] - one_epoch(weakest_model, quiz_machine) + threads = [] - log_string( - f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}" - ) + for gpu, model in zip(gpus, weakest_models): + log_string(f"training model {model.id}") - run_tests(weakest_model, quiz_machine, deterministic_synthesis=False) + t = threading.Thread( + target=one_epoch, daemon=True, args=(model, quiz_machine, gpu) + ) - log_string( - f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}" - ) + threads.append(t) - ################################################## - # Replace a fraction of the w_quizzes with fresh ones + t.start() - quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts) + for t in threads: + t.join() - ################################################## - # If all the models are good enough, generate new quizzes and - # re-compute the test errors + # Save the models to disk - if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes: - create_c_quizzes( - models, - quiz_machine, - nb_for_train=nb_new_c_quizzes_for_train, - nb_for_test=nb_new_c_quizzes_for_test, + for model in weakest_models: + filename = f"gpt_{model.id:03d}.pth" + torch.save( + (model.state_dict(), model.main_test_accuracy), + os.path.join(args.result_dir, filename), ) + log_string(f"wrote {filename}") - for model in models: - run_tests(model, quiz_machine, deterministic_synthesis=False) + # Renew the training samples + + for model in weakest_models: + quiz_machine.renew_w_quizzes(model, args.nb_train_samples) ###################################################################### diff --git a/problem.py b/problem.py index a49634d..05f3b20 100755 --- a/problem.py +++ b/problem.py @@ -9,18 +9,34 @@ import threading, queue, torch, tqdm class Problem: + def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1): + if nb_threads > 0: + self.chunk_size = chunk_size + self.queue = queue.Queue(maxsize=max_nb_cached_chunks) + for _ in range(nb_threads): + threading.Thread(target=self.fill_cache, daemon=True).start() + self.rest = None + else: + self.queue = None + + def nb_cached_quizzes(self): + if self.queue is None: + return None + else: + return self.queue.qsize() * self.chunk_size + def nb_token_values(self): pass def trivial_prompts_and_answers(self, prompts, answers): pass - # returns two tensors nb x D and nb x D' - def generate_prompts_and_answers(self, nb): + # The one to implement, returns two tensors nb x D and nb x D' + def generate_prompts_and_answers_(self, nb): pass # save a file to vizualize quizzes, you can save a txt or png file - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -31,49 +47,16 @@ class Problem: ): pass - -class MultiThreadProblem: - def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1): - self.problem = problem - self.chunk_size = chunk_size - self.queue = queue.Queue(maxsize=max_nb_cached_chunks) - for _ in range(nb_threads): - threading.Thread(target=self.fill_cache, daemon=True).start() - self.rest = None - - def nb_token_values(self): - return self.problem.nb_token_values() - - def save_quizzes( - self, - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ): - self.problem.save_quizzes( - result_dir, - filename_prefix, - prompts, - answers, - predicted_prompts=None, - predicted_answers=None, - ) - def fill_cache(self): while True: - prompts, answers = self.problem.generate_prompts_and_answers( - self.chunk_size - ) + prompts, answers = self.generate_prompts_and_answers_(self.chunk_size) self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True) - def trivial_prompts_and_answers(self, prompts, answers): - return self.problem.trivial_prompts_and_answers(prompts, answers) - def generate_prompts_and_answers(self, nb): + if self.queue is None: + return self.generate_prompts_and_answers_(nb) + if self.rest is not None: prompts, answers = rest else: @@ -105,3 +88,6 @@ class MultiThreadProblem: prompts, answers = prompts[:-k], answers[:-k] return prompts, answers + + def save_some_examples(self, result_dir): + pass diff --git a/quiz_machine.py b/quiz_machine.py index f0fb408..bc468d3 100755 --- a/quiz_machine.py +++ b/quiz_machine.py @@ -5,7 +5,7 @@ # Written by Francois Fleuret -import math, os, tqdm, warnings +import math, os, tqdm, warnings, sys import torch, torchvision @@ -15,6 +15,38 @@ from torch.nn import functional as F import mygpt from mygpt import BracketedSequence +import threading + +###################################################################### +# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X +# | X != Y) + + +# output is NxCxT and target is NxT +def confusion(output, target, reduction="mean"): + N, C, T = output.shape + output = output.permute(0, 2, 1).reshape(-1, C) + target = target.flatten() + all_t = torch.arange(N * T, device=output.device) + output = output.log_softmax(dim=-1) + result = -output[all_t, target] + + output[all_t, target] = float("-inf") + output = output.log_softmax(dim=-1) + e = output.exp() + output[all_t, target] = 0 + result = result - (output * e).sum(-1) + + if reduction == "none": + return result.reshape(N, T) + elif reduction == "mean": + return result.reshape(N, T).mean() + elif reduction == "sum": + return result.reshape(N, T).sum() + else: + raise ValueError(f"unknown reduction '{reduction}'.") + + ###################################################################### # ar_mask is a tensor with 0s and 1s, of same shape as input, with @@ -235,32 +267,18 @@ class QuizMachine: self.prompt_len = None self.answer_len = None - self.train_w_quizzes = self.generate_token_sequences(nb_train_samples) - self.reverse_random_half_in_place(self.train_w_quizzes) - self.train_w_quizzes = self.train_w_quizzes.to(device) - - self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device) - self.reverse_random_half_in_place(self.test_w_quizzes) - self.test_w_quizzes = self.test_w_quizzes.to(device) - + self.LOCK_C_QUIZZES = threading.Lock() self.train_c_quizzes = [] self.test_c_quizzes = [] - if result_dir is not None: - self.save_quizzes( - result_dir, - "culture_w_quizzes", - self.train_w_quizzes[:72], - ) - - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, quizzes, mistakes=None, ): - quizzes = quizzes.clone() + quizzes = quizzes.clone().to("cpu") n_forward = quizzes[quizzes[:, 0] == self.token_forward] n_backward = quizzes[:, 0] == self.token_backward backward = quizzes[n_backward] @@ -271,14 +289,14 @@ class QuizMachine: predicted_answers = 1 - predicted_prompts if mistakes is not None: # 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct - predicted_prompts *= mistakes - predicted_answers *= mistakes + predicted_prompts *= mistakes.to("cpu") + predicted_answers *= mistakes.to("cpu") else: # 0/2 ~ not-to-predict / to predict predicted_prompts *= 2 predicted_answers *= 2 - self.problem.save_quizzes( + self.problem.save_quiz_illustrations( result_dir, filename_prefix, quizzes[:, 1 : 1 + self.prompt_len], @@ -287,34 +305,41 @@ class QuizMachine: predicted_answers, ) - def batches(self, split="train", desc=None): - assert split in {"train", "test"} - if split == "train": - w_quizzes = self.train_w_quizzes - c_quizzes = self.train_c_quizzes - else: - w_quizzes = self.test_w_quizzes - c_quizzes = self.test_c_quizzes + def vocabulary_size(self): + return self.nb_token_values - if len(c_quizzes) > 0: - c_quizzes = torch.cat(c_quizzes, dim=0) - if c_quizzes.size(0) > w_quizzes.size(0) // 2: - i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] - c_quizzes = c_quizzes[i] + ###################################################################### - i = torch.randperm(w_quizzes.size(0))[ - : w_quizzes.size(0) - c_quizzes.size(0) - ] - w_quizzes = w_quizzes[i] + def batches(self, model, split="train", desc=None): + assert split in {"train", "test"} - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = c_quizzes.size(0) + with self.LOCK_C_QUIZZES: + if split == "train": + w_quizzes = model.train_w_quizzes + c_quizzes = self.train_c_quizzes + else: + w_quizzes = model.test_w_quizzes + c_quizzes = self.test_c_quizzes + + if len(c_quizzes) > 0: + c_quizzes = torch.cat(c_quizzes, dim=0) + if c_quizzes.size(0) > w_quizzes.size(0) // 2: + i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2] + c_quizzes = c_quizzes[i] + + i = torch.randperm(w_quizzes.size(0))[ + : w_quizzes.size(0) - c_quizzes.size(0) + ] + w_quizzes = w_quizzes[i] - input = torch.cat([w_quizzes, c_quizzes], dim=0) - else: - input = w_quizzes - self.nb_batch_w_quizzes = w_quizzes.size(0) - self.nb_batch_c_quizzes = 0 + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = c_quizzes.size(0) + + input = torch.cat([w_quizzes, c_quizzes], dim=0) + else: + input = w_quizzes + self.nb_batch_w_quizzes = w_quizzes.size(0) + self.nb_batch_c_quizzes = 0 # Shuffle input = input[torch.randperm(input.size(0))] @@ -326,13 +351,13 @@ class QuizMachine: ): yield batch - def vocabulary_size(self): - return self.nb_token_values + ###################################################################### def produce_results( self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000 ): def compute_accuracy(input, log_prefix=None): + input = input.to(self.device) ar_mask = self.make_ar_mask(input) result = input.clone() * (1 - ar_mask) seq_logproba = torch.empty(input.size(0), device=self.device) @@ -373,19 +398,15 @@ class QuizMachine: backward_nb_total = correct[n_backward].size(0) self.logger( - f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)" - ) - - self.logger( - f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)" + f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}" ) return result, correct - compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train") + # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train") test_result, test_correct = compute_accuracy( - self.test_w_quizzes[:nmax], log_prefix="test" + model.test_w_quizzes[:nmax], log_prefix="test" ) main_test_accuracy = test_correct.sum() / test_correct.size(0) @@ -393,7 +414,7 @@ class QuizMachine: ############################## - self.save_quizzes( + self.save_quiz_illustrations( result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}", quizzes=test_result[:72], @@ -402,19 +423,65 @@ class QuizMachine: return main_test_accuracy - def renew_w_quizzes(self, nb, for_train=True): - input = self.train_w_quizzes if for_train else self.test_w_quizzes + ###################################################################### + + def renew_w_quizzes(self, model, nb, for_train=True): + input = model.train_w_quizzes if for_train else model.test_w_quizzes nb = min(nb, input.size(0)) input[:-nb] = input[nb:].clone() fresh_w_quizzes = self.generate_token_sequences(nb) self.reverse_random_half_in_place(fresh_w_quizzes) - input[-nb:] = fresh_w_quizzes.to(self.device) + input[-nb:] = fresh_w_quizzes.to("cpu") + + ###################################################################### def store_c_quizzes(self, new_c_quizzes, for_train=True): - if for_train: - self.train_c_quizzes.append(new_c_quizzes) - else: - self.test_c_quizzes.append(new_c_quizzes) + with self.LOCK_C_QUIZZES: + if for_train: + self.train_c_quizzes.append(new_c_quizzes.to("cpu")) + else: + self.test_c_quizzes.append(new_c_quizzes.to("cpu")) + + def save_c_quizzes(self, filename): + torch.save((self.train_c_quizzes, self.test_c_quizzes), filename) + + def load_c_quizzes(self, filename): + self.train_c_quizzes, self.test_c_quizzes = torch.load(filename) + + ###################################################################### + + def solution_token_logprobas(self, models, c_quizzes): + logproba = c_quizzes.new_zeros( + c_quizzes.size(0), + len(models), + c_quizzes.size(1), + device=self.device, + dtype=torch.float32, + ) + + for model in models: + with torch.autograd.no_grad(): + t = model.training + model.eval() + + for input, l in zip( + c_quizzes.split(self.batch_size), logproba.split(self.batch_size) + ): + input = input.to(self.device) + ar_mask = self.make_ar_mask(input) + output = model(mygpt.BracketedSequence(input)).x + l[:, model.id] = ( + -F.cross_entropy( + output.transpose(1, 2), input, reduction="none" + ) + * ar_mask + ) + + model.train(t) + + return logproba.to("cpu") + + ############################################################### def compute_correctness( self, @@ -488,7 +555,10 @@ class QuizMachine: def generate_quizzes(self, nb, model_for_generation, temperature=1.0): c_quizzes = torch.empty( - nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64 + nb, + self.prompt_len + self.answer_len + 2, + device=self.device, + dtype=torch.int64, ) seq_logproba = torch.zeros(nb, device=self.device) @@ -538,4 +608,4 @@ class QuizMachine: device=self.device, ) - return c_quizzes + return c_quizzes.to("cpu") diff --git a/sky.py b/sky.py index ed440d3..cc5bd4f 100755 --- a/sky.py +++ b/sky.py @@ -50,7 +50,11 @@ class Sky(problem.Problem): speed=2, nb_iterations=2, avoid_collision=True, + max_nb_cached_chunks=None, + chunk_size=None, + nb_threads=-1, ): + super().__init__(max_nb_cached_chunks, chunk_size, nb_threads) self.height = height self.width = width self.nb_birds = nb_birds @@ -296,7 +300,7 @@ class Sky(problem.Problem): return prompts, answers - def save_quizzes( + def save_quiz_illustrations( self, result_dir, filename_prefix, @@ -327,7 +331,7 @@ if __name__ == "__main__": predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1 predicted_answers = torch.randint(3, (prompts.size(0),)) - 1 - sky.save_quizzes( + sky.save_quiz_illustrations( "/tmp", "test", prompts, answers, predicted_prompts, predicted_answers )