import problem
def grow_islands(nb, height, width, nb_seeds, nb_iterations):
    """Generate `nb` grids of shape (height, width) whose occupied cells form
    a few non-touching "islands", each labeled with a distinct integer >= 1
    (0 is the empty background).

    `nb_seeds` isolated cells are dropped first, the islands are then grown
    for `nb_iterations` steps, and the connected components are finally
    relabeled with consecutive integers per grid.
    """

    # Five single-channel 3x3 kernels. w[0] counts the 8 neighbors of a
    # cell; w[1:5] are directional +1/-1 patterns — the growth step below
    # only keeps cells where all four responses are >= 0, which (presumably;
    # inferred from the kernel shapes) forbids diagonal-only contacts
    # between distinct islands. TODO confirm the exact invariant.
    w = torch.empty(5, 1, 3, 3)

    w[0, 0] = torch.tensor(
        [
            [1.0, 1.0, 1.0],
            [1.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ]
    )

    w[1, 0] = torch.tensor(
        [
            [-1.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
        ]
    )

    w[2, 0] = torch.tensor(
        [
            [0.0, 1.0, -1.0],
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 0.0],
        ]
    )

    w[3, 0] = torch.tensor(
        [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, -1.0],
        ]
    )

    w[4, 0] = torch.tensor(
        [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [-1.0, 1.0, 0.0],
        ]
    )

    Z = torch.zeros(nb, height, width)
    U = Z.flatten(1)  # flat view of Z: writing U updates Z in place

    # Seeding: one pass adds at most one cell per grid, and only at
    # positions whose 8 neighbors are all empty, so every seed starts as
    # an isolated 1-cell island.
    for _ in range(nb_seeds):
        M = F.conv2d(Z[:, None, :, :], w, padding=1)
        # channel 0: occupied-neighbor count; channel 1: min of the 4
        # directional responses
        M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
        M = ((M[:, 0] == 0) & (Z == 0)).long()  # admissible positions
        # Q flags the grids that still have at least one admissible position
        Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
        # random tie-break: keep exactly one admissible position per grid
        M = M * torch.rand(M.size())
        M = M.flatten(1)
        M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
        U += M * Q

    # Growth: repeatedly occupy one admissible empty cell per grid (a cell
    # is admissible when all directional responses are non-negative).
    for _ in range(nb_iterations):
        M = F.conv2d(Z[:, None, :, :], w, padding=1)
        M = torch.cat([M[:, :1], M[:, 1:].min(dim=1, keepdim=True).values], dim=1)
        M = ((M[:, 1] >= 0) & (Z == 0)).long()
        Q = (M.flatten(1).max(dim=1).values > 0).long()[:, None]
        M = M * torch.rand(M.size())
        M = M.flatten(1)
        M = F.one_hot(M.argmax(dim=1), num_classes=M.size(1))
        U = Z.flatten(1)
        U += M * Q

    # Connected-component labeling: give every occupied cell a unique id,
    # then propagate the maximum id inside each component with a 3x3
    # max-pool (8-connectivity), masked by occupancy, until a fixed point.
    M = Z.clone()  # occupancy mask
    Z = Z * (torch.arange(Z.size(1) * Z.size(2)) + 1).reshape(1, Z.size(1), Z.size(2))

    while True:
        W = Z.clone()
        Z = F.max_pool2d(Z, 3, 1, 1) * M
        if Z.equal(W):
            break

    # Relabel per grid with consecutive integers: V flags which label
    # values occur, W[v] = number of used labels strictly below v (i.e.
    # the dense rank of v; background 0 stays 0), and the final gather
    # maps every cell's label to its dense rank.
    Z = Z.long()
    U = Z.flatten(1)
    V = F.one_hot(U).max(dim=1).values
    W = V.cumsum(dim=1) - V
    N = torch.arange(Z.size(0))[:, None, None].expand_as(Z)
    Z = W[N, Z]

    return Z
+
+
class Grids(problem.Problem):
named_colors = [
("white", [255, 255, 255]),
("gray", [128, 128, 128]),
]
    def __init__(
        self,
        max_nb_cached_chunks=None,
        chunk_size=None,
        nb_threads=-1,
        tasks=None,
    ):
        """Create a 10x10 grid-quiz generator.

        Args:
            max_nb_cached_chunks, chunk_size, nb_threads: forwarded verbatim
                to problem.Problem, which handles chunked caching and
                threaded generation (semantics defined by the base class).
            tasks: comma-separated task names without the "task_" prefix
                (e.g. "grow,frame"), or None to enable all standard tasks.
        """
        self.colors = torch.tensor([c for _, c in self.named_colors])
        self.height = 10
        self.width = 10
        # cache of pre-sampled rectangle sets used by rec_coo, keyed by
        # its sampling parameters
        self.cache_rec_coo = {}

        all_tasks = [
            self.task_replace_color,
            self.task_translate,
            self.task_grow,
            self.task_half_fill,
            self.task_frame,
            self.task_detect,
            self.task_count,
            self.task_trajectory,
            self.task_bounce,
            self.task_scale,
            self.task_symbols,
            self.task_isometry,
            # self.task_islands,
        ]

        if tasks is None:
            self.all_tasks = all_tasks
        else:
            self.all_tasks = [getattr(self, "task_" + t) for t in tasks.split(",")]

        super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
######################################################################
c = c.long()[:, None]
c = (
(1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
- * torch.tensor([64, 64, 64], device=c.device)
- + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
- + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
- + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)
+ * torch.tensor([64, 64, 64])
+ + (c == 1).long() * torch.tensor([0, 255, 0])
+ + (c == 0).long() * torch.tensor([255, 255, 255])
+ + (c == -1).long() * torch.tensor([255, 0, 0])
)
y[...] = c[:, :, None, None]
return len(self.colors)
# @torch.compile
- def rec_coo_(self, nb_rec, min_height=3, min_width=3):
- # @torch.compile
- def overlap(ia, ja, ib, jb):
- return (
- ia[1] >= ib[0] and ia[0] <= ib[1] and ja[1] >= jb[0] and ja[0] <= jb[1]
- )
+ def rec_coo(
+ self,
+ nb_rec,
+ min_height=3,
+ min_width=3,
+ surface_max=None,
+ prevent_overlap=False,
+ ):
+ if surface_max is None:
+ surface_max = self.height * self.width // 2
- if nb_rec == 3:
- while True:
- i = torch.randint(self.height + 1, (nb_rec, 2)).sort(dim=1).values
- j = torch.randint(self.width + 1, (nb_rec, 2)).sort(dim=1).values
- if (
- not (
- overlap(i[0], j[0], i[1], j[1])
- or overlap(i[0], j[0], i[2], j[2])
- or overlap(i[1], j[1], i[2], j[2])
- )
- and (i[:, 1] - i[:, 0]).min() >= min_height
- and (j[:, 1] - j[:, 0]).min() >= min_width
- ):
- break
- return (
- (i[0, 0], j[0, 0], i[0, 1], j[0, 1]),
- (i[1, 0], j[1, 0], i[1, 1], j[1, 1]),
- (i[2, 0], j[2, 0], i[2, 1], j[2, 1]),
- )
+ signature = (nb_rec, min_height, min_width, surface_max)
- # That's quite a tensorial spaghetti mess to sample
- # non-overlapping rectangles quickly, but made the generation of
- # 100k samples go from 1h50 with a lame pure python code to 3min30s
- # with this one.
- # @torch.compile
- def rec_coo(self, nb_rec, min_height=3, min_width=3):
- nb_trials = 200
+ try:
+ return self.cache_rec_coo[signature].pop()
+ except IndexError:
+ pass
+ except KeyError:
+ pass
+ N = 10000
while True:
- v = (
- (
- torch.rand(nb_trials * nb_rec, self.height + 1, device=self.device)
- .sort(dim=-1)
- .indices
- < 2
- )
- .long()
- .cumsum(dim=1)
- == 1
- ).long()
+ while True:
+ i = torch.randint(self.height, (N * nb_rec, 2)).sort(dim=-1).values
+ j = torch.randint(self.width, (N * nb_rec, 2)).sort(dim=-1).values
- h = (
- (
- torch.rand(nb_trials * nb_rec, self.width + 1, device=self.device)
- .sort(dim=-1)
- .indices
- < 2
+ big_enough = (
+ (i[:, 1] >= i[:, 0] + min_height)
+ & (j[:, 1] >= j[:, 0] + min_height)
+ & ((i[:, 1] - i[:, 0]) * (j[:, 1] - j[:, 0]) <= surface_max)
)
- .long()
- .cumsum(dim=1)
- == 1
- ).long()
-
- i = torch.logical_and(
- v.sum(dim=-1) >= min_height, h.sum(dim=-1) >= min_width
- )
- v, h = v[i], h[i]
- v = v[: v.size(0) - v.size(0) % nb_rec]
- h = h[: h.size(0) - h.size(0) % nb_rec]
- v = v.reshape(v.size(0) // nb_rec, nb_rec, -1)
- h = h.reshape(h.size(0) // nb_rec, nb_rec, -1)
+ i, j = i[big_enough], j[big_enough]
- r = v[:, :, :, None] * h[:, :, None, :]
+ n = i.size(0) - i.size(0) % nb_rec
- valid = r.sum(dim=1).flatten(1).max(dim=-1).values == 1
+ if n > 0:
+ break
- v = v[valid]
- h = h[valid]
+ i = i[:n].reshape(n // nb_rec, nb_rec, -1)
+ j = j[:n].reshape(n // nb_rec, nb_rec, -1)
+
+ if prevent_overlap:
+ can_fit = ((i[:, :, 1] - i[:, :, 0]) * (j[:, :, 1] - j[:, :, 0])).sum(
+ dim=-1
+ ) <= self.height * self.width
+ i, j = i[can_fit], j[can_fit]
+ if nb_rec == 2:
+ A_i1, A_i2, A_j1, A_j2 = (
+ i[:, 0, 0],
+ i[:, 0, 1],
+ j[:, 0, 0],
+ j[:, 0, 1],
+ )
+ B_i1, B_i2, B_j1, B_j2 = (
+ i[:, 1, 0],
+ i[:, 1, 1],
+ j[:, 1, 0],
+ j[:, 1, 1],
+ )
+ no_overlap = torch.logical_not(
+ (A_i1 >= B_i2)
+ & (A_i2 <= B_i1)
+ & (A_j1 >= B_j1)
+ & (A_j2 <= B_j1)
+ )
+ i, j = i[no_overlap], j[no_overlap]
+ elif nb_rec == 3:
+ A_i1, A_i2, A_j1, A_j2 = (
+ i[:, 0, 0],
+ i[:, 0, 1],
+ j[:, 0, 0],
+ j[:, 0, 1],
+ )
+ B_i1, B_i2, B_j1, B_j2 = (
+ i[:, 1, 0],
+ i[:, 1, 1],
+ j[:, 1, 0],
+ j[:, 1, 1],
+ )
+ C_i1, C_i2, C_j1, C_j2 = (
+ i[:, 2, 0],
+ i[:, 2, 1],
+ j[:, 2, 0],
+ j[:, 2, 1],
+ )
+ no_overlap = (
+ (
+ (A_i1 >= B_i2)
+ | (A_i2 <= B_i1)
+ | (A_j1 >= B_j2)
+ | (A_j2 <= B_j1)
+ )
+ & (
+ (A_i1 >= C_i2)
+ | (A_i2 <= C_i1)
+ | (A_j1 >= C_j2)
+ | (A_j2 <= C_j1)
+ )
+ & (
+ (B_i1 >= C_i2)
+ | (B_i2 <= C_i1)
+ | (B_j1 >= C_j2)
+ | (B_j2 <= C_j1)
+ )
+ )
+ i, j = (i[no_overlap], j[no_overlap])
+ else:
+ assert nb_rec == 1
- if v.size(0) > 0:
+ if i.size(0) > 1:
break
- av = torch.arange(v.size(2), device=self.device)[None, :]
- ah = torch.arange(h.size(2), device=self.device)[None, :]
-
- return [
- (i1.item(), j1.item(), i2.item() + 1, j2.item() + 1)
- for i1, j1, i2, j2 in zip(
- v.size(2) - (v[0] * (v.size(2) - av)).max(dim=-1).values,
- h.size(2) - (h[0] * (h.size(2) - ah)).max(dim=-1).values,
- (v[0] * av).max(dim=-1).values,
- (h[0] * ah).max(dim=-1).values,
- )
+ self.cache_rec_coo[signature] = [
+ [
+ (
+ i[n, k, 0].item(),
+ j[n, k, 0].item(),
+ i[n, k, 1].item(),
+ j[n, k, 1].item(),
+ )
+ for k in range(nb_rec)
+ ]
+ for n in range(i.size(0))
]
- # @torch.compile
- def rec_coo_(self, x, n, min_height=3, min_width=3):
- collision = x.new(x.size())
- while True:
- collision[...] = 0
- result = []
- for _ in range(n):
- while True:
- i1, i2 = torch.randint(x.size(0), (2,))
- if i1 + min_height <= i2:
- break
- while True:
- j1, j2 = torch.randint(x.size(1), (2,))
- if j1 + min_width <= j2:
- break
- collision[i1:i2, j1:j2] += 1
- if collision.max() > 1:
- break
- result.append((i1, j1, i2, j2))
- if collision.max() == 1:
- break
- return result
+ return self.cache_rec_coo[signature].pop()
######################################################################
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
# @torch.compile
def task_translate(self, A, f_A, B, f_B):
- di, dj = torch.randint(3, (2,)) - 1
+ while True:
+ di, dj = torch.randint(3, (2,)) - 1
+ if di.abs() + dj.abs() > 0:
+ break
+
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
i1, j1, i2, j2 = r[nb_rec - 1]
if (
i1 + di >= 0
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[:nb_rec] + 1
- direction = torch.randint(2, (1,))
+ direction = torch.randint(2, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
i1, j1, i2, j2 = r[nb_rec - 1]
if i1 + 3 < i2 and j1 + 3 < j2:
break
f_X[i1:i2, j1:j2] = c[n]
# @torch.compile
- def task_color_grow(self, A, f_A, B, f_B):
+ def task_half_fill(self, A, f_A, B, f_B):
di, dj = torch.randint(2, (2,)) * 2 - 1
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[: 2 * nb_rec] + 1
- direction = torch.randint(4, (1,))
+ direction = torch.randint(4, (1,)).item()
for X, f_X in [(A, f_A), (B, f_B)]:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[2 * n]
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
- f_X[i1:i2, j1:j2] = c[n]
if n == nb_rec - 1:
- f_X[i1 + 1 : i2 - 1, j1 + 1 : j2 - 1] = 0
+ f_X[i1:i2, j1] = c[n]
+ f_X[i1:i2, j2 - 1] = c[n]
+ f_X[i1, j1:j2] = c[n]
+ f_X[i2 - 1, j1:j2] = c[n]
+ else:
+ f_X[i1:i2, j1:j2] = c[n]
# @torch.compile
def task_detect(self, A, f_A, B, f_B):
nb_rec = 3
c = torch.randperm(len(self.colors) - 1)[: nb_rec + 1] + 1
for X, f_X in [(A, f_A), (B, f_B)]:
- r = self.rec_coo(nb_rec)
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
for n in range(nb_rec):
i1, j1, i2, j2 = r[n]
X[i1:i2, j1:j2] = c[n]
return no, nq, nq_diag
- # @torch.compile
    def task_count(self, A, f_A, B, f_B):
        """Generate a "count" quiz: X contains islands recolored with N random
        colors; f_X answers with, for each color c[e], a run of marks of that
        color in row e-1 whose length reflects the number of cells of that
        color in X.

        Resamples everything whenever some color ends up absent from X or a
        count does not fit in one row.
        """
        while True:
            error = False

            # N colors; c[0] stays 0 for the background
            N = torch.randint(5, (1,)).item() + 1
            c = torch.zeros(N + 1)
            c[1:] = torch.randperm(len(self.colors) - 1)[:N] + 1

            for X, f_X in [(A, f_A), (B, f_B)]:
                # Island layouts are expensive, so they are generated in
                # batches of 1000 and the surplus cached on the instance.
                if not hasattr(self, "cache_count") or len(self.cache_count) == 0:
                    self.cache_count = list(
                        grow_islands(
                            1000,
                            self.height,
                            self.width,
                            nb_seeds=self.height * self.width // 8,
                            nb_iterations=self.height * self.width // 10,
                        )
                    )

                X[...] = self.cache_count.pop()

                # Randomly map island labels to the N colors; the balanced
                # arange-based multiset biases the shuffle toward using every
                # color (the one_hot check below actually enforces it).
                k = (X.max() + 1 + (c.size(0) - 1)).item()
                V = torch.arange(k) // (c.size(0) - 1)
                V = (V + torch.rand(V.size())).sort().indices[: X.max() + 1] % (
                    c.size(0) - 1
                ) + 1
                V[0] = 0  # label 0 is the background
                X[...] = c[V[X]]

                # Keep the sample only if all N colors actually appear
                if F.one_hot(X.flatten()).max(dim=0).values.sum().item() == N + 1:
                    f_X[...] = 0
                    for e in range(1, N + 1):
                        # NOTE(review): this writes count+1 marks for count
                        # cells of color c[e] — confirm the off-by-one is
                        # intentional before relying on it.
                        for j in range((X == c[e]).sum() + 1):
                            if j < self.width:
                                f_X[e - 1, j] = c[e]
                            else:
                                # the count does not fit in one row
                                error = True
                                break
                else:
                    error = True
                    break

            if not error:
                break
# @torch.compile
def task_trajectory(self, A, f_A, B, f_B):
for X, f_X in [(A, f_A), (B, f_B)]:
while True:
di, dj = torch.randint(7, (2,)) - 3
- i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
if (
abs(di) + abs(dj) > 0
and i + 2 * di >= 0
X[...] = 0
for _ in range((self.height * self.width) // 10):
- i, j = torch.randint(self.height, (1,)), torch.randint(
- self.width, (1,)
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
)
X[i, j] = c[0]
f_X[i, j] = c[0]
if abs(di) + abs(dj) == 1:
break
- i, j = torch.randint(self.height, (1,)), torch.randint(self.width, (1,))
+ i, j = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
X[i, j] = c[1]
f_X[i, j] = c[1]
def task_scale(self, A, f_A, B, f_B):
c = torch.randperm(len(self.colors) - 1)[:2] + 1
- i, j = torch.randint(self.height // 2, (1,)), torch.randint(
- self.width // 2, (1,)
+ i, j = (
+ torch.randint(self.height // 2, (1,)).item(),
+ torch.randint(self.width // 2, (1,)).item(),
)
for X, f_X in [(A, f_A), (B, f_B)]:
for _ in range(3):
while True:
- i1, j1 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
- self.width // 2 + 1, (1,)
+ i1, j1 = (
+ torch.randint(self.height // 2 + 1, (1,)).item(),
+ torch.randint(self.width // 2 + 1, (1,)).item(),
)
- i2, j2 = torch.randint(self.height // 2 + 1, (1,)), torch.randint(
- self.width // 2 + 1, (1,)
+ i2, j2 = (
+ torch.randint(self.height // 2 + 1, (1,)).item(),
+ torch.randint(self.width // 2 + 1, (1,)).item(),
)
if i1 < i2 and j1 < j2 and min(i2 - i1, j2 - j1) <= 3:
break
ai, aj = i.float().mean(), j.float().mean()
- q = torch.randint(3, (1,)) + 1
+ q = torch.randint(3, (1,)).item() + 1
X[i[0] + delta // 2 - 1, j[0] + delta // 2 - 1] = c[0]
X[i[0] + delta // 2 - 1, j[0] + delta // 2 + 1] = c[0]
f_X[i[0] : i[0] + delta, j[0] : j[0] + delta] = c[q]
# @torch.compile
- def task_ortho(self, A, f_A, B, f_B):
+ def task_isometry(self, A, f_A, B, f_B):
nb_rec = 3
di, dj = torch.randint(3, (2,)) - 1
o = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])
m = torch.eye(2)
- for _ in range(torch.randint(4, (1,))):
+ for _ in range(torch.randint(4, (1,)).item()):
m = m @ o
if torch.rand(1) < 0.5:
m[0, :] = -m[0, :]
):
break
+ def compute_distance(self, walls, goal_i, goal_j):
+ max_length = walls.numel()
+ dist = torch.full_like(walls, max_length)
+
+ dist[goal_i, goal_j] = 0
+ pred_dist = torch.empty_like(dist)
+
+ while True:
+ pred_dist.copy_(dist)
+ dist[1:-1, 1:-1] = (
+ torch.cat(
+ (
+ dist[None, 1:-1, 1:-1],
+ dist[None, 1:-1, 0:-2],
+ dist[None, 2:, 1:-1],
+ dist[None, 1:-1, 2:],
+ dist[None, 0:-2, 1:-1],
+ ),
+ 0,
+ ).min(dim=0)[0]
+ + 1
+ )
+
+ dist = walls * max_length + (1 - walls) * dist
+
+ if dist.equal(pred_dist):
+ return dist * (1 - walls)
+
# @torch.compile
- def task_islands(self, A, f_A, B, f_B):
- pass
+ def task_distance(self, A, f_A, B, f_B):
+ c = torch.randperm(len(self.colors) - 1)[:3] + 1
+ dist0 = torch.empty(self.height + 2, self.width + 2)
+ dist1 = torch.empty(self.height + 2, self.width + 2)
+ for X, f_X in [(A, f_A), (B, f_B)]:
+ nb_rec = torch.randint(3, (1,)).item() + 1
+ while True:
+ r = self.rec_coo(nb_rec, prevent_overlap=True)
+ X[...] = 0
+ f_X[...] = 0
+ for n in range(nb_rec):
+ i1, j1, i2, j2 = r[n]
+ X[i1:i2, j1:j2] = c[0]
+ f_X[i1:i2, j1:j2] = c[0]
+ while True:
+ i0, j0 = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ if X[i0, j0] == 0:
+ break
+ while True:
+ i1, j1 = (
+ torch.randint(self.height, (1,)).item(),
+ torch.randint(self.width, (1,)).item(),
+ )
+ if X[i1, j1] == 0:
+ break
+ dist1[...] = 1
+ dist1[1:-1, 1:-1] = (X != 0).long()
+ dist1[...] = self.compute_distance(dist1, i1 + 1, j1 + 1)
+ if (
+ dist1[i0 + 1, j0 + 1] >= 1
+ and dist1[i0 + 1, j0 + 1] < self.height * 4
+ ):
+ break
+
+ dist0[...] = 1
+ dist0[1:-1, 1:-1] = (X != 0).long()
+ dist0[...] = self.compute_distance(dist0, i0 + 1, j0 + 1)
+
+ dist0 = dist0[1:-1, 1:-1]
+ dist1 = dist1[1:-1, 1:-1]
+
+ D = dist1[i0, j0]
+ for d in range(1, D):
+ M = (dist0 == d) & (dist1 == D - d)
+ f_X[...] = (1 - M) * f_X + M * c[1]
+
+ X[i0, j0] = c[2]
+ f_X[i0, j0] = c[2]
+ X[i1, j1] = c[2]
+ f_X[i1, j1] = c[2]
# for X, f_X in [(A, f_A), (B, f_B)]:
# n = torch.arange(self.height * self.width).reshape(self.height, self.width)
# i,j=q%self.height,q//self.height
# if
- ######################################################################
    # @torch.compile
    def task_puzzle(self, A, f_A, B, f_B):
        """Generate a "puzzle" quiz: f_X holds an S x S patch, centered in
        the grid, partitioned into four connected colored pieces (at most 6
        cells each); X shows the same four pieces translated to random
        non-colliding positions, for the solver to reassemble.
        """
        S = 4
        i0, j0 = (self.height - S) // 2, (self.width - S) // 2  # patch origin
        c = torch.randperm(len(self.colors) - 1)[:4] + 1
        for X, f_X in [(A, f_A), (B, f_B)]:
            # Grow four connected pieces until they tile the S x S square.
            while True:
                f_X[...] = 0
                h = list(torch.randperm(c.size(0)))  # colors not yet seeded
                n = torch.zeros(c.max() + 1)  # per-color cell counts
                for _ in range(2):  # second pass fills cells missed by the first
                    k = torch.randperm(S * S)
                    for q in k:
                        i, j = q % S + i0, q // S + j0
                        if f_X[i, j] == 0:
                            # colors of the four neighbors, visited in
                            # random order
                            r, s, t, u = (
                                f_X[i - 1, j],
                                f_X[i, j - 1],
                                f_X[i + 1, j],
                                f_X[i, j + 1],
                            )
                            r, s, t, u = torch.tensor([r, s, t, u])[torch.randperm(4)]
                            # extend an adjacent piece that is still below 6
                            # cells, otherwise seed a new color here
                            if r > 0 and n[r] < 6:
                                n[r] += 1
                                f_X[i, j] = r
                            elif s > 0 and n[s] < 6:
                                n[s] += 1
                                f_X[i, j] = s
                            elif t > 0 and n[t] < 6:
                                n[t] += 1
                                f_X[i, j] = t
                            elif u > 0 and n[u] < 6:
                                n[u] += 1
                                f_X[i, j] = u
                            else:
                                if len(h) > 0:
                                    d = c[h.pop()]
                                    n[d] += 1
                                    f_X[i, j] = d

                if n.sum() == S * S:  # the whole square got covered
                    break

            # Scatter each piece into X at a random offset where it fits
            # inside the grid without colliding with pieces already placed.
            k = 0  # NOTE(review): unused
            for d in range(4):
                while True:
                    ii, jj = (
                        torch.randint(self.height, (1,)).item(),
                        torch.randint(self.width, (1,)).item(),
                    )
                    e = 0  # collision / out-of-bounds flag
                    for i in range(S):
                        for j in range(S):
                            if (
                                ii + i >= self.height
                                or jj + j >= self.width
                                or (
                                    f_X[i + i0, j + j0] == c[d]
                                    and X[ii + i, jj + j] > 0
                                )
                            ):
                                e = 1
                    if e == 0:
                        break
                for i in range(S):
                    for j in range(S):
                        if f_X[i + i0, j + j0] == c[d]:
                            X[ii + i, jj + j] = c[d]
+
    def task_islands(self, A, f_A, B, f_B):
        """Generate an "islands" quiz: X shows islands colored c[0] with one
        cell marked c[1]; f_X recolors the whole marked island in c[1] and
        keeps the other islands in c[0].
        """
        c = torch.randperm(len(self.colors) - 1)[:2] + 1
        for X, f_X in [(A, f_A), (B, f_B)]:
            # island layouts are generated in batches of 1000 and cached
            if not hasattr(self, "cache_islands") or len(self.cache_islands) == 0:
                self.cache_islands = list(
                    grow_islands(
                        1000,
                        self.height,
                        self.width,
                        nb_seeds=self.height * self.width // 20,
                        nb_iterations=self.height * self.width // 2,
                    )
                )

            # NOTE(review): this rebinds the parameter A as the label map;
            # harmless because the (A, f_A)/(B, f_B) pairs were captured by
            # the loop above, but a different name would be clearer.
            A = self.cache_islands.pop()

            # pick a marked cell belonging to some island (restricted to
            # the upper-left quadrant)
            while True:
                i, j = (
                    torch.randint(self.height // 2, (1,)).item(),
                    torch.randint(self.width // 2, (1,)).item(),
                )
                if A[i, j] > 0:
                    break

            X[...] = (A > 0) * c[0]
            X[i, j] = c[1]
            f_X[...] = (A == A[i, j]) * c[1] + ((A > 0) & (A != A[i, j])) * c[0]
+
+ ######################################################################
def trivial_prompts_and_answers(self, prompts, answers):
S = self.height * self.width
f_Bs = answers
return (Bs == f_Bs).long().min(dim=-1).values > 0
- def generate_prompts_and_answers(
- self, nb, tasks=None, progress_bar=False, device="cpu"
- ):
+ def generate_prompts_and_answers_(self, nb, tasks=None, progress_bar=False):
if tasks is None:
- tasks = self.all_tasks()
+ tasks = self.all_tasks
S = self.height * self.width
prompts = torch.zeros(nb, 3 * S + 2, dtype=torch.int64)
f_A = prompt[1 * (S + 1) : 1 * (S + 1) + S].view(self.height, self.width)
B = prompt[2 * (S + 1) : 2 * (S + 1) + S].view(self.height, self.width)
f_B = answer.view(self.height, self.width)
- task = tasks[torch.randint(len(tasks), (1,))]
+ task = tasks[torch.randint(len(tasks), (1,)).item()]
task(A, f_A, B, f_B)
return prompts.flatten(1), answers.flatten(1)
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
nrow,
)
+ def save_some_examples(self, result_dir):
+ nb, nrow = 72, 4
+ for t in self.all_tasks:
+ print(t.__name__)
+ prompts, answers = self.generate_prompts_and_answers_(nb, tasks=[t])
+ self.save_quiz_illustrations(
+ result_dir, t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ )
+
######################################################################
if __name__ == "__main__":
import time
+ # grids = Grids(max_nb_cached_chunks=5, chunk_size=100, nb_threads=4)
grids = Grids()
# nb = 1000
# print(f"{prompts.size(0)/delay:02f} seq/s")
# exit(0)
- if True:
- nb = 72
+ # if True:
+ nb, nrow = 72, 4
+ # nb, nrow = 8, 2
- for t in grids.all_tasks():
- # for t in [grids.task_ortho]:
- print(t.__name__)
- prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
- grids.save_quizzes("/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=4)
+ # for t in grids.all_tasks:
+ for t in [grids.task_distance]:
+ print(t.__name__)
+ prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
+ grids.save_quiz_illustrations(
+ "/tmp", t.__name__, prompts[:nb], answers[:nb], nrow=nrow
+ )
- exit(0)
+ # exit(0)
- nb = 500
+ nb = 1000
- for t in grids.all_tasks():
+ # for t in grids.all_tasks:
+ for t in [grids.task_distance]:
start_time = time.perf_counter()
- prompts, answers = grids.generate_prompts_and_answers(nb, tasks=[t])
+ prompts, answers = grids.generate_prompts_and_answers_(nb, tasks=[t])
delay = time.perf_counter() - start_time
print(f"{t.__name__} {prompts.size(0)/delay:02f} seq/s")
predicted_prompts = m * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
predicted_answers = (1 - m) * (torch.randint(2, (prompts.size(0),)) * 2 - 1)
- grids.save_quizzes(
+ grids.save_quiz_illustrations(
"/tmp",
"test",
prompts[:nb],
import mygpt
import sky, grids, quiz_machine
-from problem import MultiThreadProblem
-# world quizzes vs. culture quizzes
+import threading
-######################################################################
-
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
+import torch.multiprocessing as mp
######################################################################
parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--seed", type=int, default=0)
-parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--resume", action="store_true", default=False)
+
+parser.add_argument("--max_percents_of_test_in_train", type=int, default=-1)
########################################
parser.add_argument("--nb_test_samples", type=int, default=None)
+parser.add_argument("--nb_new_c_quizzes_for_train", type=int, default=None)
+
+parser.add_argument("--nb_new_c_quizzes_for_test", type=int, default=None)
+
parser.add_argument("--learning_rate", type=float, default=5e-4)
########################################
parser.add_argument("--problem", type=str, default="grids")
-parser.add_argument("--multi_thread_problem", action="store_true", default=False)
+parser.add_argument("--nb_threads", type=int, default=1)
+
+parser.add_argument("--gpus", type=str, default="all")
parser.add_argument("--nb_gpts", type=int, default=5)
-parser.add_argument("--min_to_validate", type=int, default=None)
+parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.9)
+
+parser.add_argument("--proba_understands", type=float, default=0.9)
-parser.add_argument("--max_to_validate", type=int, default=None)
+parser.add_argument("--proba_not_understands", type=float, default=0.5)
-parser.add_argument("--accuracy_to_make_c_quizzes", type=float, default=0.975)
+parser.add_argument("--generation_temperature", type=float, default=1.0)
-parser.add_argument("--generation_temperature", type=float, default=2.0)
+parser.add_argument("--dirty_debug", action="store_true", default=False)
-parser.add_argument("--deterministic_validation", action="store_true", default=False)
+######################################################################
-parser.add_argument("--bidirectional_validation", action="store_true", default=False)
+grids_tasks = ", ".join(
+ [x.__name__.removeprefix("task_") for x in grids.Grids().all_tasks]
+)
-parser.add_argument("--dirty_debug", action="store_true", default=False)
+parser.add_argument(
+ "--grids_tasks",
+ type=str,
+ default=None,
+ help="A comma-separated subset of: " + grids_tasks + ", or None for all.",
+)
######################################################################
args = parser.parse_args()
-if args.min_to_validate is None:
- args.min_to_validate = args.nb_gpts - 1
-
-if args.max_to_validate is None:
- args.max_to_validate = args.nb_gpts - 1
-
if args.result_dir is None:
args.result_dir = f"results_culture"
default_args = {
"model": "37M",
- "batch_size": 100,
+ "batch_size": 25,
"nb_train_samples": 100000,
"nb_test_samples": 10000,
}
######################################################################
-try:
- os.mkdir(args.result_dir)
-except FileExistsError:
- print(f"result directory {args.result_dir} already exists")
- exit(1)
+if args.resume:
+ assert os.path.isdir(args.result_dir)
+
+else:
+ try:
+ os.mkdir(args.result_dir)
+ except FileExistsError:
+ print(f"result directory {args.result_dir} already exists")
+ exit(1)
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
sys.stdout.flush()
+now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
+
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py")
+
log_string(f"argv {' '.join(sys.argv)}")
for n in vars(args):
######################################################################
+if args.gpus == "all":
+ gpus_idx = range(torch.cuda.device_count())
+else:
+ gpus_idx = [int(k) for k in args.gpus.split(",")]
+
+gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
+
+if torch.cuda.is_available():
+ main_device = gpus[0]
+else:
+ assert len(gpus) == 0
+ main_device = torch.device("cpu")
+
if args.dirty_debug:
args.nb_train_samples = 2500
args.nb_test_samples = 100
nb_birds=args.sky_nb_birds,
nb_iterations=args.sky_nb_iterations,
speed=args.sky_speed,
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
)
back_accuracy = False
elif args.problem == "grids":
- problem = grids.Grids(device=device)
+ problem = grids.Grids(
+ max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
+ chunk_size=100,
+ nb_threads=args.nb_threads,
+ tasks=args.grids_tasks,
+ )
back_accuracy = True
else:
raise ValueError
-if args.multi_thread_problem:
- problem = MultiThreadProblem(problem, args.nb_train_samples, chunk_size=1000)
+problem.save_some_examples(args.result_dir)
quiz_machine = quiz_machine.QuizMachine(
problem=problem,
batch_size=args.physical_batch_size,
result_dir=args.result_dir,
logger=log_string,
- device=device,
+ device=main_device,
)
######################################################################
-log_string(f"device {device}")
+log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
vocabulary_size = quiz_machine.vocabulary_size()
######################################################################
-# Compute the entropy of the training tokens
-token_count = 0
-for input in quiz_machine.batches(split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
- (0, 1)
- )
-token_probas = token_count / token_count.sum()
-entropy = -torch.xlogy(token_probas, token_probas).sum()
-train_set_perplexity = math.exp(entropy)
def run_tests(model, quiz_machine, deterministic_synthesis, local_device=main_device):
    """Evaluate `model` on quiz_machine's test split: log the test
    perplexity and store the resulting accuracy in model.main_test_accuracy.

    Relies on the module globals n_epoch, args, log_string, mygpt and F.
    """
    with torch.autograd.no_grad():
        model.eval().to(local_device)

        nb_test_samples, acc_test_loss = 0, 0.0
        nb_samples_accumulated = 0  # NOTE(review): unused

        for input in quiz_machine.batches(model, split="test"):
            input = input.to(local_device)

            bs = model(mygpt.BracketedSequence(input))
            output = bs.x

            loss = F.cross_entropy(output.transpose(1, 2), input)

            acc_test_loss += loss.item() * input.size(0)

            nb_test_samples += input.size(0)

        # cap the exponent so an early huge loss cannot produce inf
        test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))

        log_string(f"test_perplexity {n_epoch} model {model.id} {test_perplexity}")

        model.main_test_accuracy = quiz_machine.produce_results(
            n_epoch=n_epoch,
            model=model,
            result_dir=args.result_dir,
            deterministic_synthesis=deterministic_synthesis,
        )
- model.train()
+
+def one_epoch(model, quiz_machine, local_device=main_device):
+ model.to(local_device).train()
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
nb_train_samples, acc_train_loss = 0, 0.0
- for input in quiz_machine.batches(split="train"):
- input = input.to(device)
+ for input in quiz_machine.batches(model, split="train"):
+ input = input.to(local_device)
if nb_train_samples % args.batch_size == 0:
optimizer.zero_grad()
train_perplexity = math.exp(min(100, acc_train_loss / nb_train_samples))
- log_string(f"train_perplexity {n_epoch} {train_perplexity}")
+ log_string(f"train_perplexity {n_epoch} model {model.id} {train_perplexity}")
+ run_tests(model, quiz_machine, deterministic_synthesis=False)
-######################################################################
+ model.to(main_device)
-def run_tests(model, quiz_machine, deterministic_synthesis):
- with torch.autograd.no_grad():
- model.eval()
-
- nb_test_samples, acc_test_loss = 0, 0.0
- nb_samples_accumulated = 0
-
- for input in quiz_machine.batches(split="test"):
- input = input.to(device)
-
- bs = model(mygpt.BracketedSequence(input))
- output = bs.x
+######################################################################
- loss = F.cross_entropy(output.transpose(1, 2), input)
+# This is the key routine that decides what generated quizzes to keep
- acc_test_loss += loss.item() * input.size(0)
- nb_test_samples += input.size(0)
+# token_logprobas are NxMxT where M is the number of models
- test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
- log_string(f"test_perplexity {n_epoch} {test_perplexity}")
def compute_valid_quizzes_(token_logprobas):
    """Deprecated validity test with hard-coded probability thresholds.

    token_logprobas is N x M x T (quizzes x models x tokens). A quiz is
    valid when, after sorting the per-model worst-token log-probas, the
    lowest is below log(0.1) and the second lowest is above log(0.5).
    """
    warnings.warn("validation with uniform constraints", RuntimeWarning)
    worst_per_model = token_logprobas.min(dim=-1).values
    ranked = worst_per_model.sort(dim=-1).values
    low_enough = ranked[:, 0] < math.log(0.1)
    high_enough = ranked[:, 1] > math.log(0.5)
    return low_enough & high_enough
- model.main_test_accuracy = quiz_machine.produce_results(
- n_epoch=n_epoch,
- model=model,
- result_dir=args.result_dir,
- deterministic_synthesis=deterministic_synthesis,
- )
+def compute_valid_quizzes(token_logprobas):
+ l = token_logprobas.sum(dim=-1).sort(dim=-1).values
+ return (l[:, 0] < math.log(args.proba_not_understands)) & (
+ l[:, 1] > math.log(args.proba_understands)
+ )
-######################################################################
+def extract_valid_quizzes_and_logprobas(recorded):
+ validated_quizzes, validated_logprobas = [], []
+ for quizzes, token_logprobas in recorded:
+ validated_indices = compute_valid_quizzes(token_logprobas)
+ validated_quizzes.append(quizzes[validated_indices])
+ validated_logprobas.append(token_logprobas[validated_indices])
-def valid_c_quizzes(recorded, criteria):
- result = [q[criteria(c)] for q, c in recorded]
- return torch.cat(result, dim=0) if len(result) > 0 else torch.tensor([])
+ if len(validated_quizzes) > 0:
+ return torch.cat(validated_quizzes, dim=0), torch.cat(
+ validated_logprobas, dim=0
+ )
+ else:
+ return None, None
######################################################################
-def create_c_quizzes(
- models,
- quiz_machine,
- nb_for_train=1000,
- nb_for_test=100,
-):
- quizzes_and_nb_correct_records = []
-
+def create_c_quizzes(models, quiz_machine, nb_for_train=1000, nb_for_test=100):
nb_to_create = nb_for_train + nb_for_test
- # ------------------------------------------------------------
+ recorded_quizzes_logprobas = []
- standard_validity = lambda nb_correct: torch.logical_and(
- nb_correct >= args.min_to_validate, nb_correct <= args.max_to_validate
- )
-
- file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
+ nb_validated = 0
- with open(file_name, "w") as logp_file:
- while (
- valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)
- < nb_to_create
- ):
- # Select a model at random to generate the new quizzes
+ while nb_validated < nb_to_create:
+ model_for_generation = models[torch.randint(len(models), (1,))]
- model_for_generation = models[torch.randint(len(models), (1,))]
-
- c_quizzes = quiz_machine.generate_quizzes(
- nb_to_create,
- model_for_generation=model_for_generation,
- temperature=args.generation_temperature,
- )
-
- c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
-
- if c_quizzes.size(0) > 0:
- nb_correct, seq_logproba = quiz_machine.compute_correctness(
- c_quizzes,
- models,
- bidirectional_validation=args.bidirectional_validation,
- deterministic_validation=args.deterministic_validation,
- )
-
- for n, l in zip(nb_correct, seq_logproba):
- s = " ".join([str(x.item()) for x in l])
- logp_file.write(f"{n} {s}\n")
+ c_quizzes = quiz_machine.generate_quizzes(
+ nb_to_create,
+ model_for_generation=model_for_generation,
+ temperature=args.generation_temperature,
+ )
- if args.dirty_debug:
- nb_correct = torch.randint(
- len(models) + 1, nb_correct.size(), device=c_quizzes.device
- )
+ c_quizzes = c_quizzes[quiz_machine.non_trivial(c_quizzes)]
- quizzes_and_nb_correct_records.append((c_quizzes, nb_correct))
+ if c_quizzes.size(0) > 0:
+ token_logproba = quiz_machine.solution_token_logprobas(models, c_quizzes)
+ recorded_quizzes_logprobas.append((c_quizzes, token_logproba))
- nv = F.one_hot(nb_correct, num_classes=len(models) + 1).sum(0)
- nv = " ".join([str(x.item()) for x in nv])
+ (
+ validated_quizzes,
+ validated_logprobas,
+ ) = extract_valid_quizzes_and_logprobas(recorded_quizzes_logprobas)
- nb_validated = valid_c_quizzes(
- quizzes_and_nb_correct_records, standard_validity
- ).size(0)
+ if validated_quizzes is not None:
+ nb_validated = validated_quizzes.size(0)
- log_string(
- f"keep c_quizzes model {model_for_generation.id} kept {nv} nb_accumulated {nb_validated} / {nb_to_create}"
- )
+ log_string(
+ f"keep c_quizzes model {model_for_generation.id} nb_accumulated {nb_validated} / {nb_to_create}"
+ )
# store the new c_quizzes which have been validated
- new_c_quizzes = valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity)
-
- quiz_machine.reverse_random_half_in_place(new_c_quizzes)
-
- quiz_machine.store_c_quizzes(new_c_quizzes[:nb_for_train], for_train=True)
- quiz_machine.store_c_quizzes(new_c_quizzes[nb_for_train:], for_train=False)
+ quiz_machine.reverse_random_half_in_place(validated_quizzes)
+ quiz_machine.store_c_quizzes(validated_quizzes[:nb_for_train], for_train=True)
+ quiz_machine.store_c_quizzes(
+ validated_quizzes[nb_for_train:nb_to_create], for_train=False
+ )
- # save a bunch of images to investigate what quizzes with a
- # certain nb of correct predictions look like
+ ######################################################################
+ # save images with their logprobas
- for n in range(len(models) + 1):
- s = (
- "_validated"
- if n >= args.min_to_validate and n <= args.max_to_validate
- else ""
- )
+ vq = validated_quizzes[:72]
+ vl = validated_logprobas[:72]
- q = valid_c_quizzes(
- quizzes_and_nb_correct_records, criteria=lambda nb_correct: nb_correct == n
- )[:72]
+ if vq.size(0) > 0:
+ prefix = f"culture_c_quiz_{n_epoch:04d}"
+ filename = os.path.join(args.result_dir, prefix + "_logp.pth")
+ torch.save(vl, filename)
+ # with open(file_name, "w") as logp_file:
+ # for l in vl:
+ # s = " ".join([str(x.item()) for x in l])
+ # logp_file.write(s + "\n")
- quiz_machine.reverse_random_half_in_place(q)
-
- if q.size(0) > 0:
- quiz_machine.save_quizzes(
- args.result_dir, f"culture_c_quiz_{n_epoch:04d}_N{n}{s}", q
- )
+ quiz_machine.save_quiz_illustrations(args.result_dir, prefix, vq)
######################################################################
models = []
for k in range(args.nb_gpts):
+ log_string(f"creating model {k} and its w_quizzes")
model = mygpt.MyGPT(
vocabulary_size=vocabulary_size,
dim_model=args.dim_model,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
- ).to(device)
+ ).to(main_device)
model.main_test_accuracy = 0.0
model.id = k
+ model.train_w_quizzes = quiz_machine.generate_token_sequences(args.nb_train_samples)
+ quiz_machine.reverse_random_half_in_place(model.train_w_quizzes)
+ model.test_w_quizzes = quiz_machine.generate_token_sequences(args.nb_test_samples)
+ quiz_machine.reverse_random_half_in_place(model.test_w_quizzes)
+
models.append(model)
+######################################################################
+
+if args.resume:
+ try:
+ for model in models:
+ filename = f"gpt_{model.id:03d}.pth"
+
+ try:
+ d = torch.load(os.path.join(args.result_dir, filename))
+ model.load_state_dict(d[0])
+ model.main_test_accuracy = d[1]
+ log_string(f"successfully loaded {filename}")
+ except FileNotFoundError:
+ log_string(f"cannot find {filename}")
+ pass
+
+ try:
+ filename = "c_quizzes.pth"
+ quiz_machine.load_c_quizzes(os.path.join(args.result_dir, filename))
+ log_string(f"successfully loaded {filename}")
+ except FileNotFoundError:
+ log_string(f"cannot find {filename}")
+ pass
+
+ except:
+ log_string(f"error when loading {filename}.")
+ exit(1)
+
+######################################################################
nb_parameters = sum(p.numel() for p in models[0].parameters())
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
######################################################################
-nb_new_c_quizzes_for_train = args.nb_train_samples // 50
-nb_new_c_quizzes_for_test = args.nb_test_samples // 50
+# Compute the entropy of the training tokens
+
+token_count = 0
+for input in quiz_machine.batches(models[0], split="train", desc="train-entropy"):
+ token_count += F.one_hot(input, num_classes=quiz_machine.vocabulary_size()).sum(
+ (0, 1)
+ )
+token_probas = token_count / token_count.sum()
+entropy = -torch.xlogy(token_probas, token_probas).sum()
+train_set_perplexity = math.exp(entropy)
+
+######################################################################
+# A bit of paranoia never hurts
+
+if args.max_percents_of_test_in_train >= 0:
+
+ def subsets_as_tuples(batches, cs):
+ s = set()
+ for batch in batches:
+ for x in batch:
+ s.add(tuple([v.item() for v in x]))
+ if len(s) == cs:
+ yield s
+ s = set()
+ yield s
+
+ nb_test, nb_in_train = 0, 0
+ for test_subset in subsets_as_tuples(
+ quiz_machine.batches(models[0], split="test", desc="test-check"), 25000
+ ):
+ in_train = set()
+ for train_subset in subsets_as_tuples(
+ quiz_machine.batches(models[0], split="train", desc="train-check"), 25000
+ ):
+ in_train.update(test_subset.intersection(train_subset))
+ nb_in_train += len(in_train)
+ nb_test += len(test_subset)
+
+ log_string(
+ f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
+ )
+
+ assert (
+ nb_in_train <= args.max_percents_of_test_in_train * nb_test / 100
+ ), f"More than {args.max_percents_of_test_in_train}% of test samples are in the train set"
+
+######################################################################
+
+if args.nb_new_c_quizzes_for_train is None:
+ args.nb_new_c_quizzes_for_train = args.nb_train_samples // 50
+
+if args.nb_new_c_quizzes_for_test is None:
+ args.nb_new_c_quizzes_for_test = args.nb_test_samples // 50
log_string(
- f"nb_new_c_quizzes_for_train {nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {nb_new_c_quizzes_for_test}"
+ f"nb_new_c_quizzes_for_train {args.nb_new_c_quizzes_for_train} nb_new_c_quizzes_for_test {args.nb_new_c_quizzes_for_test}"
)
######################################################################
if args.dirty_debug:
args.accuracy_to_make_c_quizzes = 0.0
args.nb_gpts = 2
- nb_new_c_quizzes_for_train = 100
- nb_new_c_quizzes_for_test = 10
+ args.nb_new_c_quizzes_for_train = 100
+ args.nb_new_c_quizzes_for_test = 10
+
######################################################################
cta = " ".join([f"{float(m.main_test_accuracy):.04f}" for m in models])
log_string(f"current_test_accuracies {cta}")
+ ##################################################
+ # If all the models are good enough, generate new quizzes and
+ # re-compute the test errors
+
+ if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
+ create_c_quizzes(
+ models,
+ quiz_machine,
+ nb_for_train=args.nb_new_c_quizzes_for_train,
+ nb_for_test=args.nb_new_c_quizzes_for_test,
+ )
+
+ filename = "c_quizzes.pth"
+ quiz_machine.save_c_quizzes(os.path.join(args.result_dir, filename))
+ log_string(f"wrote {filename}")
+
+ # Force one epoch of training
+ for model in models:
+ model.main_test_accuracy = 0.0
+
##################################################
# Select, improve, and eval the worst model
- weakest_model = min(models, key=lambda m: float(m.main_test_accuracy))
+ ranked_models = sorted(models, key=lambda m: float(m.main_test_accuracy))
- log_string(
- f"training model {weakest_model.id} main_test_accuracy {weakest_model.main_test_accuracy}"
- )
+ weakest_models = ranked_models[: len(gpus)]
- one_epoch(weakest_model, quiz_machine)
+ threads = []
- log_string(
- f"train_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
+ for gpu, model in zip(gpus, weakest_models):
+ log_string(f"training model {model.id}")
- run_tests(weakest_model, quiz_machine, deterministic_synthesis=False)
+ t = threading.Thread(
+ target=one_epoch, daemon=True, args=(model, quiz_machine, gpu)
+ )
- log_string(
- f"test_set_composition w_quizzes {quiz_machine.nb_batch_w_quizzes} c_quizzes {quiz_machine.nb_batch_c_quizzes}"
- )
+ threads.append(t)
- ##################################################
- # Replace a fraction of the w_quizzes with fresh ones
+ t.start()
- quiz_machine.renew_w_quizzes(args.nb_train_samples // args.nb_gpts)
+ for t in threads:
+ t.join()
- ##################################################
- # If all the models are good enough, generate new quizzes and
- # re-compute the test errors
+ # Save the models to disk
- if min([m.main_test_accuracy for m in models]) >= args.accuracy_to_make_c_quizzes:
- create_c_quizzes(
- models,
- quiz_machine,
- nb_for_train=nb_new_c_quizzes_for_train,
- nb_for_test=nb_new_c_quizzes_for_test,
+ for model in weakest_models:
+ filename = f"gpt_{model.id:03d}.pth"
+ torch.save(
+ (model.state_dict(), model.main_test_accuracy),
+ os.path.join(args.result_dir, filename),
)
+ log_string(f"wrote {filename}")
- for model in models:
- run_tests(model, quiz_machine, deterministic_synthesis=False)
+ # Renew the training samples
+
+ for model in weakest_models:
+ quiz_machine.renew_w_quizzes(model, args.nb_train_samples)
######################################################################
class Problem:
+ def __init__(self, max_nb_cached_chunks=None, chunk_size=None, nb_threads=-1):
+ if nb_threads > 0:
+ self.chunk_size = chunk_size
+ self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
+ for _ in range(nb_threads):
+ threading.Thread(target=self.fill_cache, daemon=True).start()
+ self.rest = None
+ else:
+ self.queue = None
+
+ def nb_cached_quizzes(self):
+ if self.queue is None:
+ return None
+ else:
+ return self.queue.qsize() * self.chunk_size
+
def nb_token_values(self):
pass
def trivial_prompts_and_answers(self, prompts, answers):
pass
- # returns two tensors nb x D and nb x D'
- def generate_prompts_and_answers(self, nb):
+ # The one to implement, returns two tensors nb x D and nb x D'
+ def generate_prompts_and_answers_(self, nb):
pass
# save a file to vizualize quizzes, you can save a txt or png file
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
):
pass
-
-class MultiThreadProblem:
- def __init__(self, problem, max_nb_cached_chunks, chunk_size, nb_threads=1):
- self.problem = problem
- self.chunk_size = chunk_size
- self.queue = queue.Queue(maxsize=max_nb_cached_chunks)
- for _ in range(nb_threads):
- threading.Thread(target=self.fill_cache, daemon=True).start()
- self.rest = None
-
- def nb_token_values(self):
- return self.problem.nb_token_values()
-
- def save_quizzes(
- self,
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- ):
- self.problem.save_quizzes(
- result_dir,
- filename_prefix,
- prompts,
- answers,
- predicted_prompts=None,
- predicted_answers=None,
- )
-
def fill_cache(self):
while True:
- prompts, answers = self.problem.generate_prompts_and_answers(
- self.chunk_size
- )
+ prompts, answers = self.generate_prompts_and_answers_(self.chunk_size)
self.queue.put((prompts.to("cpu"), answers.to("cpu")), block=True)
- def trivial_prompts_and_answers(self, prompts, answers):
- return self.problem.trivial_prompts_and_answers(prompts, answers)
-
def generate_prompts_and_answers(self, nb):
+ if self.queue is None:
+ return self.generate_prompts_and_answers_(nb)
+
if self.rest is not None:
prompts, answers = rest
else:
prompts, answers = prompts[:-k], answers[:-k]
return prompts, answers
+
+ def save_some_examples(self, result_dir):
+ pass
# Written by Francois Fleuret <francois@fleuret.org>
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
import torch, torchvision
import mygpt
from mygpt import BracketedSequence
+import threading
+
+######################################################################
+# if output is log(P(X=y)) and target is Y, returns -log P(X=Y) + H(X
+# | X != Y)
+
+
+# output is NxCxT and target is NxT
+def confusion(output, target, reduction="mean"):
+ N, C, T = output.shape
+ output = output.permute(0, 2, 1).reshape(-1, C)
+ target = target.flatten()
+ all_t = torch.arange(N * T, device=output.device)
+ output = output.log_softmax(dim=-1)
+ result = -output[all_t, target]
+
+ output[all_t, target] = float("-inf")
+ output = output.log_softmax(dim=-1)
+ e = output.exp()
+ output[all_t, target] = 0
+ result = result - (output * e).sum(-1)
+
+ if reduction == "none":
+ return result.reshape(N, T)
+ elif reduction == "mean":
+ return result.reshape(N, T).mean()
+ elif reduction == "sum":
+ return result.reshape(N, T).sum()
+ else:
+ raise ValueError(f"unknown reduction '{reduction}'.")
+
+
######################################################################
# ar_mask is a tensor with 0s and 1s, of same shape as input, with
self.prompt_len = None
self.answer_len = None
- self.train_w_quizzes = self.generate_token_sequences(nb_train_samples)
- self.reverse_random_half_in_place(self.train_w_quizzes)
- self.train_w_quizzes = self.train_w_quizzes.to(device)
-
- self.test_w_quizzes = self.generate_token_sequences(nb_test_samples).to(device)
- self.reverse_random_half_in_place(self.test_w_quizzes)
- self.test_w_quizzes = self.test_w_quizzes.to(device)
-
+ self.LOCK_C_QUIZZES = threading.Lock()
self.train_c_quizzes = []
self.test_c_quizzes = []
- if result_dir is not None:
- self.save_quizzes(
- result_dir,
- "culture_w_quizzes",
- self.train_w_quizzes[:72],
- )
-
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
quizzes,
mistakes=None,
):
- quizzes = quizzes.clone()
+ quizzes = quizzes.clone().to("cpu")
n_forward = quizzes[quizzes[:, 0] == self.token_forward]
n_backward = quizzes[:, 0] == self.token_backward
backward = quizzes[n_backward]
predicted_answers = 1 - predicted_prompts
if mistakes is not None:
# 0/-1/+1 ~ not-to-predict / predicted wrong / predicted correct
- predicted_prompts *= mistakes
- predicted_answers *= mistakes
+ predicted_prompts *= mistakes.to("cpu")
+ predicted_answers *= mistakes.to("cpu")
else:
# 0/2 ~ not-to-predict / to predict
predicted_prompts *= 2
predicted_answers *= 2
- self.problem.save_quizzes(
+ self.problem.save_quiz_illustrations(
result_dir,
filename_prefix,
quizzes[:, 1 : 1 + self.prompt_len],
predicted_answers,
)
- def batches(self, split="train", desc=None):
- assert split in {"train", "test"}
- if split == "train":
- w_quizzes = self.train_w_quizzes
- c_quizzes = self.train_c_quizzes
- else:
- w_quizzes = self.test_w_quizzes
- c_quizzes = self.test_c_quizzes
+ def vocabulary_size(self):
+ return self.nb_token_values
- if len(c_quizzes) > 0:
- c_quizzes = torch.cat(c_quizzes, dim=0)
- if c_quizzes.size(0) > w_quizzes.size(0) // 2:
- i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
- c_quizzes = c_quizzes[i]
+ ######################################################################
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
+ def batches(self, model, split="train", desc=None):
+ assert split in {"train", "test"}
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = c_quizzes.size(0)
+ with self.LOCK_C_QUIZZES:
+ if split == "train":
+ w_quizzes = model.train_w_quizzes
+ c_quizzes = self.train_c_quizzes
+ else:
+ w_quizzes = model.test_w_quizzes
+ c_quizzes = self.test_c_quizzes
+
+ if len(c_quizzes) > 0:
+ c_quizzes = torch.cat(c_quizzes, dim=0)
+ if c_quizzes.size(0) > w_quizzes.size(0) // 2:
+ i = torch.randperm(c_quizzes.size(0))[: w_quizzes.size(0) // 2]
+ c_quizzes = c_quizzes[i]
+
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- input = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = c_quizzes.size(0)
+
+ input = torch.cat([w_quizzes, c_quizzes], dim=0)
+ else:
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
):
yield batch
- def vocabulary_size(self):
- return self.nb_token_values
+ ######################################################################
def produce_results(
self, n_epoch, model, result_dir, deterministic_synthesis, nmax=1000
):
def compute_accuracy(input, log_prefix=None):
+ input = input.to(self.device)
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
seq_logproba = torch.empty(input.size(0), device=self.device)
backward_nb_total = correct[n_backward].size(0)
self.logger(
- f"{log_prefix}_forward_accuracy {n_epoch} model {model.id} nb_correct {forward_nb_correct} / {forward_nb_total} ({forward_nb_correct*100/forward_nb_total} %)"
- )
-
- self.logger(
- f"{log_prefix}_backward_accuracy {n_epoch} model {model.id} nb_correct {backward_nb_correct} / {backward_nb_total} ({backward_nb_correct*100/backward_nb_total} %)"
+ f"{log_prefix}_accuracy {n_epoch} model {model.id} forward {forward_nb_correct} / {forward_nb_total} backward {backward_nb_correct} / {backward_nb_total}"
)
return result, correct
- compute_accuracy(self.train_w_quizzes[:nmax], log_prefix="train")
+ # compute_accuracy(model.train_w_quizzes[:nmax], log_prefix="train")
test_result, test_correct = compute_accuracy(
- self.test_w_quizzes[:nmax], log_prefix="test"
+ model.test_w_quizzes[:nmax], log_prefix="test"
)
main_test_accuracy = test_correct.sum() / test_correct.size(0)
##############################
- self.save_quizzes(
+ self.save_quiz_illustrations(
result_dir,
f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
quizzes=test_result[:72],
return main_test_accuracy
- def renew_w_quizzes(self, nb, for_train=True):
- input = self.train_w_quizzes if for_train else self.test_w_quizzes
+ ######################################################################
+
+ def renew_w_quizzes(self, model, nb, for_train=True):
+ input = model.train_w_quizzes if for_train else model.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
fresh_w_quizzes = self.generate_token_sequences(nb)
self.reverse_random_half_in_place(fresh_w_quizzes)
- input[-nb:] = fresh_w_quizzes.to(self.device)
+ input[-nb:] = fresh_w_quizzes.to("cpu")
+
+ ######################################################################
def store_c_quizzes(self, new_c_quizzes, for_train=True):
- if for_train:
- self.train_c_quizzes.append(new_c_quizzes)
- else:
- self.test_c_quizzes.append(new_c_quizzes)
+ with self.LOCK_C_QUIZZES:
+ if for_train:
+ self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
+ else:
+ self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
+
+ def save_c_quizzes(self, filename):
+ torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
+
+ def load_c_quizzes(self, filename):
+ self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
+
+ ######################################################################
+
+ def solution_token_logprobas(self, models, c_quizzes):
+ logproba = c_quizzes.new_zeros(
+ c_quizzes.size(0),
+ len(models),
+ c_quizzes.size(1),
+ device=self.device,
+ dtype=torch.float32,
+ )
+
+ for model in models:
+ with torch.autograd.no_grad():
+ t = model.training
+ model.eval()
+
+ for input, l in zip(
+ c_quizzes.split(self.batch_size), logproba.split(self.batch_size)
+ ):
+ input = input.to(self.device)
+ ar_mask = self.make_ar_mask(input)
+ output = model(mygpt.BracketedSequence(input)).x
+ l[:, model.id] = (
+ -F.cross_entropy(
+ output.transpose(1, 2), input, reduction="none"
+ )
+ * ar_mask
+ )
+
+ model.train(t)
+
+ return logproba.to("cpu")
+
+ ###############################################################
def compute_correctness(
self,
def generate_quizzes(self, nb, model_for_generation, temperature=1.0):
c_quizzes = torch.empty(
- nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
+ nb,
+ self.prompt_len + self.answer_len + 2,
+ device=self.device,
+ dtype=torch.int64,
)
seq_logproba = torch.zeros(nb, device=self.device)
device=self.device,
)
- return c_quizzes
+ return c_quizzes.to("cpu")
speed=2,
nb_iterations=2,
avoid_collision=True,
+ max_nb_cached_chunks=None,
+ chunk_size=None,
+ nb_threads=-1,
):
+ super().__init__(max_nb_cached_chunks, chunk_size, nb_threads)
self.height = height
self.width = width
self.nb_birds = nb_birds
return prompts, answers
- def save_quizzes(
+ def save_quiz_illustrations(
self,
result_dir,
filename_prefix,
predicted_prompts = torch.randint(3, (prompts.size(0),)) - 1
predicted_answers = torch.randint(3, (prompts.size(0),)) - 1
- sky.save_quizzes(
+ sky.save_quiz_illustrations(
"/tmp", "test", prompts, answers, predicted_prompts, predicted_answers
)