+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, re
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-
-def random_var(nb_variables=None, variables=None):
- if variables is None:
- return chr(ord("A") + torch.randint(nb_variables, (1,)).item())
- else:
- l = list(variables)
- return l[torch.randint(len(l), (1,)).item()]
-
-
-def random_expr(variables, operand_max, budget):
- if budget <= 5:
- op = torch.randint(2, (1,)).item()
- if op == 0 and len(variables) > 0:
- return random_var(variables=variables)
- else:
- return str(torch.randint(operand_max + 1, (1,)).item())
- else:
- op = torch.randint(3, (1,)).item()
- if op == 0:
- e = random_expr(variables, operand_max, budget - 2)
- if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"):
- return "(" + e + ")"
- else:
- return e
- else:
- b = 2 + torch.randint(budget - 5, (1,)).item()
- e1 = random_expr(variables, operand_max, b)
- e2 = random_expr(variables, operand_max, budget - b - 1)
- if op == 1:
- return e1 + "+" + e2
- elif op == 2:
- return e1 + "*" + e2
-
-
-def generate_program(nb_variables, operand_max, length):
- s = ""
- variables = set()
-
- while len(s) < length:
- v = random_var(nb_variables=nb_variables)
- s += v + "=" + random_expr(variables, operand_max, budget=20) + ";"
- variables.add(v)
-
- return s, variables
-
-
-def generate_sequences(nb, nb_variables=5, length=20, operand_max=9, result_max=99):
- assert nb_variables <= 26
- sequences = []
-
- for n in range(nb):
- # We take length itself half of the time, and uniform between
- # 1 and length otherwise. The actual length can be slightly
- # greater
-
- l = min(length, 1 + torch.randint(length * 2, (1,)).item())
- result = None
- while result == None or max(result.values()) > result_max:
- p, v = generate_program(nb_variables, operand_max, l)
- v = ", ".join(['"' + v + '": ' + v for v in v])
- ldict = {}
- exec(p + "result={" + v + "}", globals(), ldict)
- result = ldict["result"]
-
- k = list(result.keys())
- k.sort()
- sequences.append(p + " " + "".join([v + ":" + str(result[v]) + ";" for v in k]))
-
- return sequences
-
-
-def extract_results(seq):
- f = lambda a: (a[0], -1 if a[1] == "" else int(a[1]))
- results = [
- dict([f(tuple(x.split(":"))) for x in re.findall("[A-Z]:[0-9]*", s)])
- for s in seq
- ]
- return results
-
-
-if __name__ == "__main__":
- import time
-
- start_time = time.perf_counter()
- sequences = generate_sequences(1000, length=40)
- end_time = time.perf_counter()
- for s in sequences[:10]:
- print(s)
- print(f"{len(sequences) / (end_time - start_time):.02f} samples per second")
-
- print(extract_results(sequences[:10]))
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch
-
-from torch.nn import functional as F
-
-######################################################################
-
-REWARD_PLUS = 1
-REWARD_NONE = 0
-REWARD_MINUS = -1
-REWARD_UNKNOWN = 2
-
-
-class GreedWorld:
- def __init__(self, height=6, width=6, T=10, nb_walls=3, nb_coins=2):
- self.height = height
- self.width = width
- self.T = T
- self.nb_walls = nb_walls
- self.nb_coins = nb_coins
-
- self.nb_states_codes = 5
- self.nb_actions_codes = 5
- self.nb_rewards_codes = 3
- self.nb_lookahead_rewards_codes = 4 # stands for -1, 0, +1, and UNKNOWN
-
- self.first_states_code = 0
- self.first_actions_code = self.first_states_code + self.nb_states_codes
- self.first_rewards_code = self.first_actions_code + self.nb_actions_codes
- self.first_lookahead_rewards_code = (
- self.first_rewards_code + self.nb_rewards_codes
- )
- self.nb_codes = (
- self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
- )
-
- self.state_len = self.height * self.width
- self.index_lookahead_reward = 0
- self.index_states = 1
- self.index_reward = self.state_len + 1
- self.index_action = self.state_len + 2
- self.it_len = self.state_len + 3 # lookahead_reward / state / reward / action
-
- def state2code(self, r):
- return r + self.first_states_code
-
- def code2state(self, r):
- return r - self.first_states_code
-
- def action2code(self, r):
- return r + self.first_actions_code
-
- def code2action(self, r):
- return r - self.first_actions_code
-
- def reward2code(self, r):
- return r + 1 + self.first_rewards_code
-
- def code2reward(self, r):
- return r - self.first_rewards_code - 1
-
- def lookahead_reward2code(self, r):
- # -1, 0, +1 or 2 for UNKNOWN
- return r + 1 + self.first_lookahead_rewards_code
-
- def code2lookahead_reward(self, r):
- return r - self.first_lookahead_rewards_code - 1
-
- ######################################################################
-
- def generate_episodes(self, nb):
- rnd = torch.rand(nb, self.height, self.width)
- rnd[:, 0, :] = 0
- rnd[:, -1, :] = 0
- rnd[:, :, 0] = 0
- rnd[:, :, -1] = 0
- wall = 0
- for k in range(self.nb_walls):
- wall = wall + (
- rnd.flatten(1).argmax(dim=1)[:, None]
- == torch.arange(rnd.flatten(1).size(1))[None, :]
- ).long().reshape(rnd.size())
-
- rnd = rnd * (1 - wall.clamp(max=1))
-
- rnd = torch.rand(nb, self.height, self.width)
- rnd[:, 0, 0] = 0 # Do not put coin at the agent's starting
- # position
- coins = torch.zeros(nb, self.T, self.height, self.width, dtype=torch.int64)
- rnd = rnd * (1 - wall.clamp(max=1))
- for k in range(self.nb_coins):
- coins[:, 0] = coins[:, 0] + (
- rnd.flatten(1).argmax(dim=1)[:, None]
- == torch.arange(rnd.flatten(1).size(1))[None, :]
- ).long().reshape(rnd.size())
-
- rnd = rnd * (1 - coins[:, 0].clamp(max=1))
-
- states = wall[:, None, :, :].expand(-1, self.T, -1, -1).clone()
-
- agent = torch.zeros(states.size(), dtype=torch.int64)
- agent[:, 0, 0, 0] = 1
- agent_actions = torch.randint(5, (nb, self.T))
- rewards = torch.zeros(nb, self.T, dtype=torch.int64)
-
- troll = torch.zeros(states.size(), dtype=torch.int64)
- troll[:, 0, -1, -1] = 1
- troll_actions = torch.randint(5, (nb, self.T))
-
- all_moves = agent.new(nb, 5, self.height, self.width)
- for t in range(self.T - 1):
- all_moves.zero_()
- all_moves[:, 0] = agent[:, t]
- all_moves[:, 1, 1:, :] = agent[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = agent[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = agent[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = agent[:, t, :, 1:]
- a = F.one_hot(agent_actions[:, t], num_classes=5)[:, :, None, None]
- after_move = (all_moves * a).sum(dim=1)
- collision = (
- (after_move * (1 - wall) * (1 - troll[:, t]))
- .flatten(1)
- .sum(dim=1)[:, None, None]
- == 0
- ).long()
- agent[:, t + 1] = collision * agent[:, t] + (1 - collision) * after_move
-
- all_moves.zero_()
- all_moves[:, 0] = troll[:, t]
- all_moves[:, 1, 1:, :] = troll[:, t, :-1, :]
- all_moves[:, 2, :-1, :] = troll[:, t, 1:, :]
- all_moves[:, 3, :, 1:] = troll[:, t, :, :-1]
- all_moves[:, 4, :, :-1] = troll[:, t, :, 1:]
- a = F.one_hot(troll_actions[:, t], num_classes=5)[:, :, None, None]
- after_move = (all_moves * a).sum(dim=1)
- collision = (
- (after_move * (1 - wall) * (1 - agent[:, t + 1]))
- .flatten(1)
- .sum(dim=1)[:, None, None]
- == 0
- ).long()
- troll[:, t + 1] = collision * troll[:, t] + (1 - collision) * after_move
-
- hit = (
- (agent[:, t + 1, 1:, :] * troll[:, t + 1, :-1, :]).flatten(1).sum(dim=1)
- + (agent[:, t + 1, :-1, :] * troll[:, t + 1, 1:, :])
- .flatten(1)
- .sum(dim=1)
- + (agent[:, t + 1, :, 1:] * troll[:, t + 1, :, :-1])
- .flatten(1)
- .sum(dim=1)
- + (agent[:, t + 1, :, :-1] * troll[:, t + 1, :, 1:])
- .flatten(1)
- .sum(dim=1)
- )
- hit = (hit > 0).long()
-
- # assert hit.min() == 0 and hit.max() <= 1
-
- got_coin = (agent[:, t + 1] * coins[:, t]).flatten(1).sum(dim=1)
- coins[:, t + 1] = coins[:, t] * (1 - agent[:, t + 1])
-
- rewards[:, t + 1] = -hit + (1 - hit) * got_coin
-
- states = states + 2 * agent + 3 * troll + 4 * coins * (1 - troll)
-
- return states, agent_actions, rewards
-
- ######################################################################
-
- def episodes2seq(self, states, actions, rewards):
- neg = rewards.new_zeros(rewards.size())
- pos = rewards.new_zeros(rewards.size())
- for t in range(neg.size(1)):
- neg[:, t] = rewards[:, t:].min(dim=-1).values
- pos[:, t] = rewards[:, t:].max(dim=-1).values
- s = (neg < 0).long() * neg + (neg >= 0).long() * pos
-
- return torch.cat(
- [
- self.lookahead_reward2code(s[:, :, None]),
- self.state2code(states.flatten(2)),
- self.reward2code(rewards[:, :, None]),
- self.action2code(actions[:, :, None]),
- ],
- dim=2,
- ).flatten(1)
-
- def seq2episodes(self, seq):
- seq = seq.reshape(seq.size(0), -1, self.height * self.width + 3)
- lookahead_rewards = self.code2lookahead_reward(
- seq[:, :, self.index_lookahead_reward]
- )
- states = self.code2state(
- seq[:, :, self.index_states : self.height * self.width + self.index_states]
- )
- states = states.reshape(states.size(0), states.size(1), self.height, self.width)
- actions = self.code2action(seq[:, :, self.index_action])
- rewards = self.code2reward(seq[:, :, self.index_reward])
- return lookahead_rewards, states, actions, rewards
-
- def seq2str(self, seq):
- def token2str(t):
- if (
- t >= self.first_states_code
- and t < self.first_states_code + self.nb_states_codes
- ):
- return "_#@T$"[t - self.first_states_code]
- elif (
- t >= self.first_actions_code
- and t < self.first_actions_code + self.nb_actions_codes
- ):
- return "ISNEW"[t - self.first_actions_code]
- elif (
- t >= self.first_rewards_code
- and t < self.first_rewards_code + self.nb_rewards_codes
- ):
- return "-0+"[t - self.first_rewards_code]
- elif (
- t >= self.first_lookahead_rewards_code
- and t
- < self.first_lookahead_rewards_code + self.nb_lookahead_rewards_codes
- ):
- return "n.pU"[t - self.first_lookahead_rewards_code]
- else:
- return "?"
-
- return ["".join([token2str(x.item()) for x in row]) for row in seq]
-
- ######################################################################
-
- def episodes2str(
- self,
- lookahead_rewards,
- states,
- actions,
- rewards,
- unicode=False,
- ansi_colors=False,
- ):
- if unicode:
- symbols = "·█@T$"
- # vert, hori, cross, thin_hori = "║", "═", "╬", "─"
- vert, hori, cross, thin_vert, thin_hori = "┃", "━", "╋", "│", "─"
- else:
- symbols = " #@T$"
- vert, hori, cross, thin_vert, thin_hori = "|", "-", "+", "|", "-"
-
- hline = (cross + hori * states.size(-1)) * states.size(1) + cross + "\n"
-
- result = hline
-
- for n in range(states.size(0)):
-
- def state_symbol(v):
- v = v.item()
- return "?" if v < 0 or v >= len(symbols) else symbols[v]
-
- for i in range(states.size(2)):
- result += (
- vert
- + vert.join(
- [
- "".join([state_symbol(v) for v in row])
- for row in states[n, :, i]
- ]
- )
- + vert
- + "\n"
- )
-
- # result += (vert + thin_hori * states.size(-1)) * states.size(1) + vert + "\n"
-
- def status_bar(a, r, lr=None):
- a, r = a.item(), r.item()
- sb_a = "ISNEW"[a] if a >= 0 and a < 5 else "?"
- sb_r = "- +"[r + 1] if r in {-1, 0, 1} else "?"
- if lr is None:
- sb_lr = ""
- else:
- lr = lr.item()
- sb_lr = "n pU"[lr + 1] if lr in {-1, 0, 1, 2} else "?"
- return (
- sb_a
- + "/"
- + sb_r
- + " " * (states.size(-1) - 1 - len(sb_a + sb_r + sb_lr))
- + sb_lr
- )
-
- result += (
- vert
- + vert.join(
- [
- status_bar(a, r, lr)
- for a, r, lr in zip(
- actions[n], rewards[n], lookahead_rewards[n]
- )
- ]
- )
- + vert
- + "\n"
- )
-
- result += hline
-
- if ansi_colors:
- for u, c in [("T", 31), ("@", 32), ("$", 34)]:
- result = result.replace(u, f"\u001b[{c}m{u}\u001b[0m")
-
- return result
-
- ######################################################################
-
- def save_seq_as_anim_script(self, seq, filename):
- it_len = self.height * self.width + 3
-
- seq = (
- seq.reshape(seq.size(0), -1, it_len)
- .permute(1, 0, 2)
- .reshape(self.T, seq.size(0), -1)
- )
-
- with open(filename, "w") as f:
- for t in range(self.T):
- # f.write("clear\n")
- f.write("cat << EOF\n")
- f.write("\u001b[H")
- # for i in range(seq.size(2)):
- # lr, s, a, r = seq2episodes(seq[t : t + 1, :, i], self.height, self.width)
- lr, s, a, r = self.seq2episodes(seq[t : t + 1, :].reshape(8, -1))
- f.write(self.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
- f.write("EOF\n")
- f.write("sleep 0.25\n")
- print(f"Saved {filename}")
-
-
-if __name__ == "__main__":
- gw = GreedWorld(height=5, width=7, T=10, nb_walls=4, nb_coins=2)
- states, actions, rewards = gw.generate_episodes(nb=6)
- seq = gw.episodes2seq(states, actions, rewards)
- lr, s, a, r = gw.seq2episodes(seq)
- print(gw.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True))
-
- print()
- for s in gw.seq2str(seq):
- print(s)
-
- gw = GreedWorld(height=5, width=7, T=100, nb_walls=4, nb_coins=2)
- states, actions, rewards = gw.generate_episodes(nb=128)
- seq = gw.episodes2seq(states, actions, rewards)
- gw.save_seq_as_anim_script(seq, "anim.sh")
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math
-import torch, torchvision
-import torch.nn.functional as F
-
-######################################################################
-
-
-class GridFactory:
- def __init__(
- self,
- size=6,
- max_nb_items=4,
- max_nb_transformations=3,
- nb_questions=4,
- nb_shapes=6,
- nb_colors=6,
- nb_play_steps=3,
- ):
- assert size % 2 == 0
- self.size = size
- self.max_nb_items = max_nb_items
- self.max_nb_transformations = max_nb_transformations
- self.nb_questions = nb_questions
- self.nb_play_steps = nb_play_steps
- self.name_shapes = ["A", "B", "C", "D", "E", "F"]
- self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
- self.vname_shapes = ["vA", "vB", "vC", "vD", "vE", "vF"]
- self.vname_colors = ["vred", "vyellow", "vblue", "vgreen", "vwhite", "vpurple"]
-
- def generate_scene(self):
- nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
- col = torch.full((self.size * self.size,), -1)
- shp = torch.full((self.size * self.size,), -1)
- a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
- col[:nb_items] = a % len(self.name_colors)
- shp[:nb_items] = a // len(self.name_colors)
- i = torch.randperm(self.size * self.size)
- col = col[i]
- shp = shp[i]
- return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
-
- def random_object_move(self, scene):
- col, shp = scene
- while True:
- a = (col.flatten() >= 0).nonzero()
- a = a[torch.randint(a.size(0), (1,)).item()]
- i, j = a // self.size, a % self.size
- assert col[i, j] >= 0
- dst = [(i, j), (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
- dst = list(
- filter(
- lambda x: x[0] >= 0
- and x[1] >= 0
- and x[0] < self.size
- and x[1] < self.size
- and col[x[0], x[1]] < 0,
- dst,
- )
- )
- if len(dst) > 0:
- ni, nj = dst[torch.randint(len(dst), (1,)).item()]
- col[ni, nj] = col[i, j]
- shp[ni, nj] = shp[i, j]
- col[i, j] = -1
- shp[i, j] = -1
- break
-
- return col, shp
-
- def transformation(self, t, scene):
- col, shp = scene
- if t == 0:
- col, shp = col.flip(0), shp.flip(0)
- description = "<chg> vertical flip"
- elif t == 1:
- col, shp = col.flip(1), shp.flip(1)
- description = "<chg> horizontal flip"
- elif t == 2:
- col, shp = col.flip(0).t(), shp.flip(0).t()
- description = "<chg> rotate 90 degrees"
- elif t == 3:
- col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
- description = "<chg> rotate 180 degrees"
- elif t == 4:
- col, shp = col.flip(1).t(), shp.flip(1).t()
- description = "<chg> rotate 270 degrees"
-
- return (col.contiguous(), shp.contiguous()), description
-
- def random_transformations(self, scene):
- descriptions = []
- nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
- transformations = torch.randint(5, (nb_transformations,))
-
- for t in transformations:
- scene, description = self.transformation(t, scene)
- descriptions += [description]
-
- return scene, descriptions
-
- def visual_scene2str(self, scene):
- col, shp = scene
- r = []
- for i in range(self.size):
- s = []
- for j in range(self.size):
- if col[i, j] >= 0:
- s += [self.vname_colors[col[i, j]], self.vname_shapes[shp[i, j]]]
- else:
- s += ["v_", "v+"]
- r += s # .append(" ".join(s))
- return " ".join(r)
-
- def print_scene(self, scene):
- col, shp = scene
-
- # for i in range(self.size):
- # for j in range(self.size):
- # if col[i,j] >= 0:
- # print(f"at ({i},{j}) {self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}")
-
- for i in range(self.size):
- for j in range(self.size):
- if col[i, j] >= 0:
- print(
- f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
- end="",
- )
- elif j == 0:
- print(" +", end="")
- else:
- print("-+", end="")
- if j < self.size - 1:
- print("--", end="")
- else:
- print("")
- if i < self.size - 1:
- for j in range(self.size - 1):
- print(" | ", end="")
- print(" |")
-
- def grid_positions(self, scene):
- col, shp = scene
-
- properties = []
-
- for i in range(self.size):
- for j in range(self.size):
- if col[i, j] >= 0:
- n = f"{self.name_colors[col[i,j]]} {self.name_shapes[shp[i,j]]}"
- properties += [f"a {n} at {i} {j}"]
-
- return properties
-
- def all_properties(self, scene):
- col, shp = scene
-
- properties = []
-
- for i1 in range(self.size):
- for j1 in range(self.size):
- if col[i1, j1] >= 0:
- n1 = (
- f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
- )
- properties += [f"there is a {n1}"]
- if i1 < self.size // 2:
- properties += [f"a {n1} is in the top half"]
- if i1 >= self.size // 2:
- properties += [f"a {n1} is in the bottom half"]
- if j1 < self.size // 2:
- properties += [f"a {n1} is in the left half"]
- if j1 >= self.size // 2:
- properties += [f"a {n1} is in the right half"]
- for i2 in range(self.size):
- for j2 in range(self.size):
- if col[i2, j2] >= 0:
- n2 = f"{self.name_colors[col[i2,j2]]} {self.name_shapes[shp[i2,j2]]}"
- if i1 > i2:
- properties += [f"a {n1} is below a {n2}"]
- if i1 < i2:
- properties += [f"a {n1} is above a {n2}"]
- if j1 > j2:
- properties += [f"a {n1} is right of a {n2}"]
- if j1 < j2:
- properties += [f"a {n1} is left of a {n2}"]
- if abs(i1 - i2) + abs(j1 - j2) == 1:
- properties += [f"a {n1} is next to a {n2}"]
-
- return properties
-
- def generate_scene_and_play(self):
- scene = self.generate_scene()
- steps = [self.visual_scene2str(scene)]
- for t in range(self.nb_play_steps - 1):
- if torch.randint(4, (1,)).item() == 0:
- scene, _ = self.transformation(torch.randint(5, (1,)), scene)
- else:
- scene = self.random_object_move(scene)
- steps.append(self.visual_scene2str(scene))
- return " | ".join(steps)
-
- def generate_scene_and_questions(self):
- while True:
- # We generate scenes until we get one with enough
- # properties
-
- while True:
- start_scene = self.generate_scene()
- scene, transformations = self.random_transformations(start_scene)
- true = self.all_properties(scene)
- if len(true) >= self.nb_questions:
- break
-
- # We generate a bunch of false properties by shuffling the
- # scene and sometimes adding properties from totally
- # different scenes. We try ten times to get enough false
- # properties and go back to generating the scene if we do
- # not succeed
-
- for a in range(10):
- col, shp = scene
- col, shp = col.view(-1), shp.view(-1)
- p = torch.randperm(col.size(0))
- col, shp = col[p], shp[p]
- other_scene = (
- col.view(self.size, self.size),
- shp.view(self.size, self.size),
- )
-
- false = self.all_properties(other_scene)
-
- # We sometime add properties from a totally different
- # scene to have negative "there is a xxx xxx"
- # properties
-
- if torch.rand(1).item() < 0.2:
- other_scene = self.generate_scene()
- false += self.all_properties(other_scene)
-
- false = list(set(false) - set(true))
- if len(false) >= self.nb_questions:
- break
-
- if a < 10:
- break
-
- true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
- false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
- true = ["<prop> " + q + " <ans> true" for q in true]
- false = ["<prop> " + q + " <ans> false" for q in false]
-
- union = true + false
- questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
-
- result = " ".join(
- ["<obj> " + x for x in self.grid_positions(start_scene)]
- + transformations
- + questions
- )
-
- return start_scene, scene, result
-
- def generate_samples(self, nb, fraction_play=0.0, progress_bar=None):
- result = []
-
- play = torch.rand(nb) < fraction_play
- if progress_bar is not None:
- play = progress_bar(play)
-
- for p in play:
- if p:
- result.append(self.generate_scene_and_play())
- else:
- result.append(self.generate_scene_and_questions()[2])
-
- return result
-
-
-######################################################################
-
-if __name__ == "__main__":
- import time
-
- grid_factory = GridFactory()
-
- # start_time = time.perf_counter()
- # samples = grid_factory.generate_samples(10000)
- # end_time = time.perf_counter()
- # print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
-
- start_scene, scene, questions = grid_factory.generate_scene_and_questions()
- print()
- print("-- Original scene -----------------------------")
- print()
- grid_factory.print_scene(start_scene)
- print()
- print("-- Transformed scene --------------------------")
- print()
- grid_factory.print_scene(scene)
- print()
- print("-- Sequence -----------------------------------")
- print()
- print(questions)
-
- # print(grid_factory.visual_scene2str(scene))
-
- # grid_factory.print_scene(scene)
- # for t in range(5):
- # scene = grid_factory.random_object_move(scene)
- # print()
- # grid_factory.print_scene(scene)
-
- print(grid_factory.generate_scene_and_play())
-
-######################################################################
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-parser.add_argument(
- "--task",
- type=str,
- default="world",
- help="file, byheart, learnop, guessop, mixing, memory, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid, qmlp, greed",
-)
+parser.add_argument("--task", type=str, default="world", help="world")
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-##############################
-# filetask
-
-parser.add_argument("--filetask_train_file", type=str, default=None)
-
-parser.add_argument("--filetask_test_file", type=str, default=None)
-
-##############################
-# rpl options
-
-parser.add_argument("--rpl_nb_starting_values", type=int, default=3)
-
-parser.add_argument("--rpl_max_input", type=int, default=9)
-
-parser.add_argument("--rpl_prog_len", type=int, default=8)
-
-parser.add_argument("--rpl_nb_runs", type=int, default=5)
-
-parser.add_argument("--rpl_no_prog", action="store_true", default=False)
-
-##############################
-# grid options
-
-parser.add_argument("--grid_size", type=int, default=6)
-
-parser.add_argument("--grid_fraction_play", type=float, default=0)
-
-##############################
-# picoclvr options
-
-parser.add_argument("--picoclvr_nb_colors", type=int, default=5)
-
-parser.add_argument("--picoclvr_height", type=int, default=12)
-
-parser.add_argument("--picoclvr_width", type=int, default=16)
-
-parser.add_argument("--picocvlr_prune_properties", type=str, default="none")
-
-##############################
-# Maze options
-
-parser.add_argument("--maze_height", type=int, default=13)
-
-parser.add_argument("--maze_width", type=int, default=21)
-
-parser.add_argument("--maze_nb_walls", type=int, default=15)
-
-##############################
-# Snake options
-
-parser.add_argument("--snake_height", type=int, default=9)
-
-parser.add_argument("--snake_width", type=int, default=12)
-
-parser.add_argument("--snake_nb_colors", type=int, default=5)
-
-parser.add_argument("--snake_length", type=int, default=200)
-
-##############################
-# ByHeart options
-
-parser.add_argument("--byheart_separation", type=int, default=1)
-
-##############################
-# Stack options
-
-parser.add_argument("--stack_nb_steps", type=int, default=100)
-
-parser.add_argument("--stack_nb_stacks", type=int, default=3)
-
-parser.add_argument("--stack_nb_digits", type=int, default=3)
-
-parser.add_argument("--stack_fraction_values_for_train", type=float, default=None)
-
-##############################
-# Expr options
-
-parser.add_argument("--expr_nb_variables", type=int, default=5)
-
-parser.add_argument("--expr_sequence_length", type=int, default=40)
-
-parser.add_argument("--expr_operand_max", type=int, default=9)
-
-parser.add_argument("--expr_result_max", type=int, default=99)
-
-parser.add_argument("--expr_input_file", type=str, default=None)
-
-##############################
-# Mixing
-
-parser.add_argument("--mixing_hard", action="store_true", default=False)
-
-parser.add_argument("--mixing_deterministic_start", action="store_true", default=False)
-
-##############################
-# greed options
-
-parser.add_argument("--greed_height", type=int, default=5)
-
-parser.add_argument("--greed_width", type=int, default=7)
-
-parser.add_argument("--greed_T", type=int, default=25)
-
-parser.add_argument("--greed_nb_walls", type=int, default=5)
-
-parser.add_argument("--greed_nb_coins", type=int, default=2)
-
######################################################################
args = parser.parse_args()
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
-
if args.result_dir is None:
args.result_dir = f"results_{args.task}"
"nb_train_samples": 250000,
"nb_test_samples": 10000,
},
- "file": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 250000,
- "nb_test_samples": 10000,
- },
- "addition": {
- "model": "352M",
- "batch_size": 25,
- "nb_train_samples": 250000,
- "nb_test_samples": 10000,
- },
- "byheart": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 50000,
- "nb_test_samples": 10000,
- },
- "expr": {
- "model": "352M",
- "batch_size": 25,
- "nb_train_samples": 2500000,
- "nb_test_samples": 10000,
- },
- "grid": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 250000,
- "nb_test_samples": 10000,
- },
- "qmlp": {
- "model": "37M",
- "batch_size": 10,
- "nb_train_samples": 100000,
- "nb_test_samples": 1000,
- },
- "guessop": {
- "model": "352M",
- "batch_size": 25,
- "nb_train_samples": 1000000,
- "nb_test_samples": 10000,
- },
- "learnop": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 50000,
- "nb_test_samples": 10000,
- },
- "maze": {
- "model": "37M",
- "batch_size": 5,
- "nb_train_samples": 100000,
- "nb_test_samples": 10000,
- },
- "picoclvr": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 250000,
- "nb_test_samples": 10000,
- },
- "rpl": {
- "model": "352M",
- "batch_size": 5,
- "nb_train_samples": 2500000,
- "nb_test_samples": 10000,
- },
- "snake": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 250000,
- "nb_test_samples": 10000,
- },
- "stack": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 100000,
- "nb_test_samples": 1000,
- },
- "twotargets": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 50000,
- "nb_test_samples": 10000,
- },
- "memory": {
- "model": "37M",
- "batch_size": 100,
- "nb_train_samples": 25000,
- "nb_test_samples": 1000,
- },
- "mixing": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 250000,
- "nb_test_samples": 10000,
- },
- "mnist": {
- "model": "37M",
- "batch_size": 10,
- "nb_train_samples": 60000,
- "nb_test_samples": 10000,
- },
- "greed": {
- "model": "37M",
- "batch_size": 25,
- "nb_train_samples": 25000,
- "nb_test_samples": 10000,
- },
}
if args.task in default_task_args:
######################################################################
-def picoclvr_pruner_horizontal_green(p):
- return not ("green" in p and ("left" in p or "right" in p))
-
-
-picoclvr_pruner_train = (
- picoclvr_pruner_horizontal_green
- if args.picocvlr_prune_properties in {"train+eval"}
- else None
-)
-
-picoclvr_pruner_eval = (
- (lambda p: not picoclvr_pruner_horizontal_green(p))
- if args.picocvlr_prune_properties in {"train+eval", "eval"}
- else None
-)
-
-######################################################################
-
if args.physical_batch_size is None:
args.physical_batch_size = args.batch_size
else:
accuracy_to_make_quizzes = 0.975
for n_epoch in range(args.nb_epochs):
+ # select the model with lowest accuracy
models.sort(key=lambda model: model.main_test_accuracy)
-
model = models[0]
log_string(
f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
)
+ # improve it
one_epoch(model, task)
log_string(
f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
)
+ # test it
run_tests(model, task, deterministic_synthesis=False)
if model.main_test_accuracy >= accuracy_to_make_quizzes:
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch, torchvision
-
-######################################################################
-
-v_empty, v_wall, v_start, v_goal, v_path = 0, 1, 2, 3, 4
-
-
-def create_maze(h=11, w=17, nb_walls=8):
- assert h % 2 == 1 and w % 2 == 1
-
- nb_attempts, nb_added_walls = 0, 0
-
- while nb_added_walls < nb_walls:
- while True:
- if nb_attempts == 0:
- m = torch.zeros(h, w, dtype=torch.int64)
- m[0, :] = 1
- m[-1, :] = 1
- m[:, 0] = 1
- m[:, -1] = 1
-
- r = torch.rand(4)
-
- if r[0] <= 0.5:
- # Add a vertical wall
- i1, i2, j = (
- int((r[1] * h).item()),
- int((r[2] * h).item()),
- int((r[3] * w).item()),
- )
- i1, i2, j = i1 - i1 % 2, i2 - i2 % 2, j - j % 2
- i1, i2 = min(i1, i2), max(i1, i2)
-
- # If this wall does not hit another one, add it
- if i2 - i1 > 1 and i2 - i1 <= h / 2 and m[i1 : i2 + 1, j].sum() <= 1:
- m[i1 : i2 + 1, j] = 1
- break
-
- else:
- # Add an horizontal wall
- i, j1, j2 = (
- int((r[1] * h).item()),
- int((r[2] * w).item()),
- int((r[3] * w).item()),
- )
- i, j1, j2 = i - i % 2, j1 - j1 % 2, j2 - j2 % 2
- j1, j2 = min(j1, j2), max(j1, j2)
-
- # If this wall does not hit another one, add it
- if j2 - j1 > 1 and j2 - j1 <= w / 2 and m[i, j1 : j2 + 1].sum() <= 1:
- m[i, j1 : j2 + 1] = 1
- break
-
- nb_attempts += 1
-
- if nb_attempts > 10 * nb_walls:
- nb_attempts, nb_added_walls = 0, 0
-
- nb_added_walls += 1
-
- return m
-
-
-######################################################################
-
-
-def compute_distance(walls, goal_i, goal_j):
- max_length = walls.numel()
- dist = torch.full_like(walls, max_length)
-
- dist[goal_i, goal_j] = 0
- pred_dist = torch.empty_like(dist)
-
- while True:
- pred_dist.copy_(dist)
- d = (
- torch.cat(
- (
- dist[None, 1:-1, 0:-2],
- dist[None, 2:, 1:-1],
- dist[None, 1:-1, 2:],
- dist[None, 0:-2, 1:-1],
- ),
- 0,
- ).min(dim=0)[0]
- + 1
- )
-
- dist[1:-1, 1:-1] = torch.min(dist[1:-1, 1:-1], d)
- dist = walls * max_length + (1 - walls) * dist
-
- if dist.equal(pred_dist):
- return dist * (1 - walls)
-
-
-######################################################################
-
-
-def compute_policy(walls, goal_i, goal_j):
- distance = compute_distance(walls, goal_i, goal_j)
- distance = distance + walls.numel() * walls
-
- value = distance.new_full((4,) + distance.size(), walls.numel())
- value[0, :, 1:] = distance[:, :-1] # <
- value[1, :, :-1] = distance[:, 1:] # >
- value[2, 1:, :] = distance[:-1, :] # ^
- value[3, :-1, :] = distance[1:, :] # v
-
- proba = (value.min(dim=0)[0][None] == value).float()
- proba = proba / proba.sum(dim=0)[None]
- proba = proba * (1 - walls) + walls.float() / 4
-
- return proba
-
-
-def stationary_densities(mazes, policies):
- policies = policies * (mazes != v_goal)[:, None]
- start = (mazes == v_start).nonzero(as_tuple=True)
- probas = mazes.new_zeros(mazes.size(), dtype=torch.float32)
- pred_probas = probas.clone()
- probas[start] = 1.0
-
- while not pred_probas.equal(probas):
- pred_probas.copy_(probas)
- probas.zero_()
- probas[:, 1:, :] += pred_probas[:, :-1, :] * policies[:, 3, :-1, :]
- probas[:, :-1, :] += pred_probas[:, 1:, :] * policies[:, 2, 1:, :]
- probas[:, :, 1:] += pred_probas[:, :, :-1] * policies[:, 1, :, :-1]
- probas[:, :, :-1] += pred_probas[:, :, 1:] * policies[:, 0, :, 1:]
- probas[start] = 1.0
-
- return probas
-
-
-######################################################################
-
-
-def mark_path(walls, i, j, goal_i, goal_j, policy):
- action = torch.distributions.categorical.Categorical(
- policy.permute(1, 2, 0)
- ).sample()
- n, nmax = 0, walls.numel()
- while i != goal_i or j != goal_j:
- di, dj = [(0, -1), (0, 1), (-1, 0), (1, 0)][action[i, j]]
- i, j = i + di, j + dj
- assert walls[i, j] == 0
- walls[i, j] = v_path
- n += 1
- assert n < nmax
-
-
-def path_optimality(ref_paths, paths):
- return (ref_paths == v_path).long().flatten(1).sum(1) == (
- paths == v_path
- ).long().flatten(1).sum(1)
-
-
-def path_correctness(mazes, paths):
- still_ok = (mazes - (paths * (paths != v_path))).view(mazes.size(0), -1).abs().sum(
- 1
- ) == 0
- reached = still_ok.new_zeros(still_ok.size())
- current, pred_current = paths.clone(), paths.new_zeros(paths.size())
- goal = (mazes == v_goal).long()
- while not pred_current.equal(current):
- pred_current.copy_(current)
- u = (current == v_start).long()
- possible_next = (
- u[:, 2:, 1:-1] + u[:, 0:-2, 1:-1] + u[:, 1:-1, 2:] + u[:, 1:-1, 0:-2] > 0
- ).long()
- u = u[:, 1:-1, 1:-1]
- reached += ((goal[:, 1:-1, 1:-1] * possible_next).sum((1, 2)) == 1) * (
- (current == v_path).sum((1, 2)) == 0
- )
- current[:, 1:-1, 1:-1] = (1 - u) * current[:, 1:-1, 1:-1] + (
- v_start - v_path
- ) * (possible_next * (current[:, 1:-1, 1:-1] == v_path))
- still_ok *= (current == v_start).sum((1, 2)) <= 1
-
- return still_ok * reached
-
-
-######################################################################
-
-
-def create_maze_data(
- nb, height=11, width=17, nb_walls=8, dist_min=10, progress_bar=lambda x: x
-):
- mazes = torch.empty(nb, height, width, dtype=torch.int64)
- paths = torch.empty(nb, height, width, dtype=torch.int64)
- policies = torch.empty(nb, 4, height, width)
-
- for n in progress_bar(range(nb)):
- maze = create_maze(height, width, nb_walls)
- i = (maze == v_empty).nonzero()
- while True:
- start, goal = i[torch.randperm(i.size(0))[:2]]
- if (start - goal).abs().sum() >= dist_min:
- break
- start_i, start_j, goal_i, goal_j = start[0], start[1], goal[0], goal[1]
-
- policy = compute_policy(maze, goal_i, goal_j)
- path = maze.clone()
- mark_path(path, start_i, start_j, goal_i, goal_j, policy)
- maze[start_i, start_j] = v_start
- maze[goal_i, goal_j] = v_goal
- path[start_i, start_j] = v_start
- path[goal_i, goal_j] = v_goal
-
- mazes[n] = maze
- paths[n] = path
- policies[n] = policy
-
- return mazes, paths, policies
-
-
-######################################################################
-
-
-def save_image(
- name,
- mazes,
- target_paths=None,
- predicted_paths=None,
- path_correct=None,
- path_optimal=None,
-):
- colors = torch.tensor(
- [
- [255, 255, 255], # empty
- [0, 0, 0], # wall
- [0, 255, 0], # start
- [127, 127, 255], # goal
- [255, 0, 0], # path
- ]
- )
-
- mazes = mazes.cpu()
-
- c_mazes = (
- colors[mazes.reshape(-1)].reshape(mazes.size() + (-1,)).permute(0, 3, 1, 2)
- )
-
- imgs = c_mazes.unsqueeze(1)
-
- if target_paths is not None:
- target_paths = target_paths.cpu()
-
- c_target_paths = (
- colors[target_paths.reshape(-1)]
- .reshape(target_paths.size() + (-1,))
- .permute(0, 3, 1, 2)
- )
-
- imgs = torch.cat((imgs, c_target_paths.unsqueeze(1)), 1)
-
- if predicted_paths is not None:
- predicted_paths = predicted_paths.cpu()
- c_predicted_paths = (
- colors[predicted_paths.reshape(-1)]
- .reshape(predicted_paths.size() + (-1,))
- .permute(0, 3, 1, 2)
- )
- imgs = torch.cat((imgs, c_predicted_paths.unsqueeze(1)), 1)
-
- img = torch.tensor([255, 255, 0]).view(1, -1, 1, 1)
-
- # NxKxCxHxW
- if path_optimal is not None:
- path_optimal = path_optimal.cpu().long().view(-1, 1, 1, 1)
- img = (
- img * (1 - path_optimal)
- + torch.tensor([0, 255, 0]).view(1, -1, 1, 1) * path_optimal
- )
-
- if path_correct is not None:
- path_correct = path_correct.cpu().long().view(-1, 1, 1, 1)
- img = img * path_correct + torch.tensor([255, 0, 0]).view(1, -1, 1, 1) * (
- 1 - path_correct
- )
-
- img = img.expand(
- -1, -1, imgs.size(3) + 2, 1 + imgs.size(1) * (1 + imgs.size(4))
- ).clone()
-
- print(f"{img.size()=} {imgs.size()=}")
-
- for k in range(imgs.size(1)):
- img[
- :,
- :,
- 1 : 1 + imgs.size(3),
- 1 + k * (1 + imgs.size(4)) : 1 + k * (1 + imgs.size(4)) + imgs.size(4),
- ] = imgs[:, k]
-
- img = img.float() / 255.0
-
- torchvision.utils.save_image(img, name, nrow=4, padding=1, pad_value=224.0 / 256)
-
-
-######################################################################
-
-if __name__ == "__main__":
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- mazes, paths, policies = create_maze_data(8)
- mazes, paths = mazes.to(device), paths.to(device)
- save_image("test.png", mazes=mazes, target_paths=paths, predicted_paths=paths)
- print(path_correctness(mazes, paths))
-
-######################################################################
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math
-import torch, torchvision
-import torch.nn.functional as F
-
-color_name2rgb = {
- "white": [255, 255, 255],
- "red": [255, 0, 0],
- "green": [0, 128, 0],
- "blue": [0, 0, 255],
- "yellow": [255, 255, 0],
- "black": [0, 0, 0],
- "maroon": [128, 0, 0],
- "dark_red": [139, 0, 0],
- "brown": [165, 42, 42],
- "firebrick": [178, 34, 34],
- "crimson": [220, 20, 60],
- "tomato": [255, 99, 71],
- "coral": [255, 127, 80],
- "indian_red": [205, 92, 92],
- "light_coral": [240, 128, 128],
- "dark_salmon": [233, 150, 122],
- "salmon": [250, 128, 114],
- "light_salmon": [255, 160, 122],
- "orange_red": [255, 69, 0],
- "dark_orange": [255, 140, 0],
- "orange": [255, 165, 0],
- "gold": [255, 215, 0],
- "dark_golden_rod": [184, 134, 11],
- "golden_rod": [218, 165, 32],
- "pale_golden_rod": [238, 232, 170],
- "dark_khaki": [189, 183, 107],
- "khaki": [240, 230, 140],
- "olive": [128, 128, 0],
- "yellow_green": [154, 205, 50],
- "dark_olive_green": [85, 107, 47],
- "olive_drab": [107, 142, 35],
- "lawn_green": [124, 252, 0],
- "chartreuse": [127, 255, 0],
- "green_yellow": [173, 255, 47],
- "dark_green": [0, 100, 0],
- "forest_green": [34, 139, 34],
- "lime": [0, 255, 0],
- "lime_green": [50, 205, 50],
- "light_green": [144, 238, 144],
- "pale_green": [152, 251, 152],
- "dark_sea_green": [143, 188, 143],
- "medium_spring_green": [0, 250, 154],
- "spring_green": [0, 255, 127],
- "sea_green": [46, 139, 87],
- "medium_aqua_marine": [102, 205, 170],
- "medium_sea_green": [60, 179, 113],
- "light_sea_green": [32, 178, 170],
- "dark_slate_gray": [47, 79, 79],
- "teal": [0, 128, 128],
- "dark_cyan": [0, 139, 139],
- "aqua": [0, 255, 255],
- "cyan": [0, 255, 255],
- "light_cyan": [224, 255, 255],
- "dark_turquoise": [0, 206, 209],
- "turquoise": [64, 224, 208],
- "medium_turquoise": [72, 209, 204],
- "pale_turquoise": [175, 238, 238],
- "aqua_marine": [127, 255, 212],
- "powder_blue": [176, 224, 230],
- "cadet_blue": [95, 158, 160],
- "steel_blue": [70, 130, 180],
- "corn_flower_blue": [100, 149, 237],
- "deep_sky_blue": [0, 191, 255],
- "dodger_blue": [30, 144, 255],
- "light_blue": [173, 216, 230],
- "sky_blue": [135, 206, 235],
- "light_sky_blue": [135, 206, 250],
- "midnight_blue": [25, 25, 112],
- "navy": [0, 0, 128],
- "dark_blue": [0, 0, 139],
- "medium_blue": [0, 0, 205],
- "royal_blue": [65, 105, 225],
- "blue_violet": [138, 43, 226],
- "indigo": [75, 0, 130],
- "dark_slate_blue": [72, 61, 139],
- "slate_blue": [106, 90, 205],
- "medium_slate_blue": [123, 104, 238],
- "medium_purple": [147, 112, 219],
- "dark_magenta": [139, 0, 139],
- "dark_violet": [148, 0, 211],
- "dark_orchid": [153, 50, 204],
- "medium_orchid": [186, 85, 211],
- "purple": [128, 0, 128],
- "thistle": [216, 191, 216],
- "plum": [221, 160, 221],
- "violet": [238, 130, 238],
- "magenta": [255, 0, 255],
- "orchid": [218, 112, 214],
- "medium_violet_red": [199, 21, 133],
- "pale_violet_red": [219, 112, 147],
- "deep_pink": [255, 20, 147],
- "hot_pink": [255, 105, 180],
- "light_pink": [255, 182, 193],
- "pink": [255, 192, 203],
- "antique_white": [250, 235, 215],
- "beige": [245, 245, 220],
- "bisque": [255, 228, 196],
- "blanched_almond": [255, 235, 205],
- "wheat": [245, 222, 179],
- "corn_silk": [255, 248, 220],
- "lemon_chiffon": [255, 250, 205],
- "light_golden_rod_yellow": [250, 250, 210],
- "light_yellow": [255, 255, 224],
- "saddle_brown": [139, 69, 19],
- "sienna": [160, 82, 45],
- "chocolate": [210, 105, 30],
- "peru": [205, 133, 63],
- "sandy_brown": [244, 164, 96],
- "burly_wood": [222, 184, 135],
- "tan": [210, 180, 140],
- "rosy_brown": [188, 143, 143],
- "moccasin": [255, 228, 181],
- "navajo_white": [255, 222, 173],
- "peach_puff": [255, 218, 185],
- "misty_rose": [255, 228, 225],
- "lavender_blush": [255, 240, 245],
- "linen": [250, 240, 230],
- "old_lace": [253, 245, 230],
- "papaya_whip": [255, 239, 213],
- "sea_shell": [255, 245, 238],
- "mint_cream": [245, 255, 250],
- "slate_gray": [112, 128, 144],
- "light_slate_gray": [119, 136, 153],
- "light_steel_blue": [176, 196, 222],
- "lavender": [230, 230, 250],
- "floral_white": [255, 250, 240],
- "alice_blue": [240, 248, 255],
- "ghost_white": [248, 248, 255],
- "honeydew": [240, 255, 240],
- "ivory": [255, 255, 240],
- "azure": [240, 255, 255],
- "snow": [255, 250, 250],
- "silver": [192, 192, 192],
- "gainsboro": [220, 220, 220],
- "white_smoke": [245, 245, 245],
-}
-
-color_name2id = dict([(n, k) for k, n in enumerate(color_name2rgb.keys())])
-color_id2name = dict([(k, n) for k, n in enumerate(color_name2rgb.keys())])
-
-######################################################################
-
-
-def all_properties(height, width, nb_squares, square_i, square_j, square_c):
- s = []
-
- for r, c_r in [(k, color_id2name[square_c[k].item()]) for k in range(nb_squares)]:
- s += [f"there is {c_r}"]
-
- if square_i[r] >= height - height // 3:
- s += [f"{c_r} bottom"]
- if square_i[r] < height // 3:
- s += [f"{c_r} top"]
- if square_j[r] >= width - width // 3:
- s += [f"{c_r} right"]
- if square_j[r] < width // 3:
- s += [f"{c_r} left"]
-
- for t, c_t in [
- (k, color_id2name[square_c[k].item()]) for k in range(nb_squares)
- ]:
- if square_i[r] > square_i[t]:
- s += [f"{c_r} below {c_t}"]
- if square_i[r] < square_i[t]:
- s += [f"{c_r} above {c_t}"]
- if square_j[r] > square_j[t]:
- s += [f"{c_r} right of {c_t}"]
- if square_j[r] < square_j[t]:
- s += [f"{c_r} left of {c_t}"]
-
- return s
-
-
-######################################################################
-
-# Generates sequences
-
-
-def generate(
- nb,
- height,
- width,
- max_nb_squares=5,
- max_nb_properties=10,
- nb_colors=5,
- pruner=None,
-):
- assert nb_colors >= max_nb_squares and nb_colors <= len(color_name2rgb) - 1
-
- descr = []
-
- for n in range(nb):
- # we want uniform over the combinations of 1 to max_nb_squares
- # pixels of nb_colors
- logits = math.log(nb_colors) * torch.arange(1, max_nb_squares + 1).float()
- dist = torch.distributions.categorical.Categorical(logits=logits)
- nb_squares = dist.sample((1,)) + 1
- # nb_squares = torch.randint(max_nb_squares, (1,)) + 1
- square_position = torch.randperm(height * width)[:nb_squares]
-
- # color 0 is white and reserved for the background
- square_c = torch.randperm(nb_colors)[:nb_squares] + 1
- square_i = square_position.div(width, rounding_mode="floor")
- square_j = square_position % width
-
- img = torch.zeros(height * width, dtype=torch.int64)
- for k in range(nb_squares):
- img[square_position[k]] = square_c[k]
-
- # generates all the true properties
-
- s = all_properties(height, width, nb_squares, square_i, square_j, square_c)
-
- if pruner is not None:
- s = list(filter(pruner, s))
-
- # picks at most max_nb_properties at random
-
- nb_properties = torch.randint(max_nb_properties, (1,)) + 1
- s = (
- " <sep> ".join([s[k] for k in torch.randperm(len(s))[:nb_properties]])
- + " <img> "
- + " ".join([f"{color_id2name[n.item()]}" for n in img])
- )
-
- descr += [s]
-
- return descr
-
-
-######################################################################
-
-# Extracts the image after <img> in descr as a 1x3xHxW tensor
-
-
-def descr2img(descr, height, width):
- result = []
-
- def token2color(t):
- try:
- return color_name2rgb[t]
- except KeyError:
- return [128, 128, 128]
-
- for d in descr:
- d = d.split("<img>")[1]
- d = d.strip().split(" ")[: height * width]
- d = d + ["<unk>"] * (height * width - len(d))
- d = [token2color(t) for t in d]
- img = torch.tensor(d).permute(1, 0).reshape(1, 3, height, width)
- result.append(img)
-
- return torch.cat(result, 0)
-
-
-######################################################################
-
-# Returns all the properties of the image after <img> in descr
-
-
-def descr2properties(descr, height, width):
- if type(descr) == list:
- return [descr2properties(d, height, width) for d in descr]
-
- d = descr.split("<img>")
- img_tokens = d[-1] if len(d) > 1 else ""
- img_tokens = img_tokens.strip().split(" ")[: height * width]
- if len(img_tokens) != height * width:
- return []
-
- seen = {}
- for k, x in enumerate(img_tokens):
- if x != color_id2name[0]:
- if x in color_name2rgb:
- if x in seen:
- return []
- else:
- return []
- seen[x] = (color_name2id[x], k // width, k % width)
-
- square_infos = tuple(zip(*seen.values()))
-
- if square_infos:
- square_c = torch.tensor(square_infos[0])
- square_i = torch.tensor(square_infos[1])
- square_j = torch.tensor(square_infos[2])
- else:
- square_c = torch.tensor([])
- square_i = torch.tensor([])
- square_j = torch.tensor([])
-
- s = all_properties(height, width, len(seen), square_i, square_j, square_c)
-
- return s
-
-
-######################################################################
-
-# Returns a triplet composed of (1) the total number of properties
-# before <img> in descr, (2) the total number of properties the image
-# after <img> verifies, and (3) the number of properties in (1) not in
-# (2)
-
-
-def nb_properties(descr, height, width, pruner=None):
- if type(descr) == list:
- return [nb_properties(d, height, width, pruner) for d in descr]
-
- d = descr.split("<img>", 1)
- if len(d) == 0:
- return 0
- d = d[0].strip().split("<sep>")
- d = [x.strip() for x in d]
-
- all_properties = set(descr2properties(descr, height, width))
-
- if pruner is None:
- requested_properties = set(d)
- else:
- requested_properties = set(filter(pruner, d))
-
- missing_properties = requested_properties - all_properties
-
- return (len(requested_properties), len(all_properties), len(missing_properties))
-
-
-######################################################################
-
-if __name__ == "__main__":
- for n in range(16):
- descr = generate(nb=1, height=12, width=16)
-
- print(nb_properties(descr, height=12, width=16))
-
- with open(f"picoclvr_example_{n:02d}.txt", "w") as f:
- for d in descr:
- f.write(f"{d}\n\n")
-
- img = descr2img(descr, height=12, width=16)
- if img.size(0) == 1:
- img = F.pad(img, (1, 1, 1, 1), value=64)
-
- torchvision.utils.save_image(
- img / 255.0,
- f"picoclvr_example_{n:02d}.png",
- padding=1,
- nrow=4,
- pad_value=0.8,
- )
-
- import time
-
- start_time = time.perf_counter()
- descr = generate(nb=1000, height=12, width=16)
- end_time = time.perf_counter()
- print(f"{len(descr) / (end_time - start_time):.02f} samples per second")
-
-######################################################################
+++ /dev/null
-#!/usr/bin/env python
-
-# @XREMOTE_HOST: elk.fleuret.org
-# @XREMOTE_EXEC: python
-# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
-# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
-# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
-# @XREMOTE_SEND: *.py *.sh
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, sys
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-nb_quantization_levels = 101
-
-
-def quantize(x, xmin, xmax):
- return (
- ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
- .long()
- .clamp(min=0, max=nb_quantization_levels - 1)
- )
-
-
-def dequantize(q, xmin, xmax):
- return q / nb_quantization_levels * (xmax - xmin) + xmin
-
-
-######################################################################
-
-
-def generate_sets_and_params(
- batch_nb_mlps,
- nb_samples,
- batch_size,
- nb_epochs,
- device=torch.device("cpu"),
- print_log=False,
- save_as_examples=False,
-):
- data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
- data_targets = torch.zeros(
- batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
- )
-
- nb_rec = 8
- nb_values = 2 # more increases the min-max gap
-
- rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
-
- while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
- i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
- nb = i.sum()
- support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
- support = support.sort(-1).values
- support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
-
- x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
- y = (
- (
- (x[:, None, :, 0] >= support[:, :, None, 0]).long()
- * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
- * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
- * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
- )
- .max(dim=1)
- .values
- )
-
- data_input[i], data_targets[i], rec_support[i] = x, y, support
-
- train_input, train_targets = (
- data_input[:, :nb_samples],
- data_targets[:, :nb_samples],
- )
- test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]
-
- q_train_input = quantize(train_input, -1, 1)
- train_input = dequantize(q_train_input, -1, 1)
-
- q_test_input = quantize(test_input, -1, 1)
- test_input = dequantize(q_test_input, -1, 1)
-
- if save_as_examples:
- a = (
- 2
- * torch.arange(nb_quantization_levels).float()
- / (nb_quantization_levels - 1)
- - 1
- )
- xf = torch.cat(
- [
- a[:, None, None].expand(
- nb_quantization_levels, nb_quantization_levels, 1
- ),
- a[None, :, None].expand(
- nb_quantization_levels, nb_quantization_levels, 1
- ),
- ],
- 2,
- )
- xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
- print(f"{xf.size()=} {x.size()=}")
- yf = (
- (
- (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
- * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
- * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
- * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
- )
- .max(dim=1)
- .values
- )
-
- full_input, full_targets = xf, yf
-
- q_full_input = quantize(full_input, -1, 1)
- full_input = dequantize(q_full_input, -1, 1)
-
- for k in range(q_full_input[:10].size(0)):
- with open(f"example_full_{k:04d}.dat", "w") as f:
- for u, c in zip(full_input[k], full_targets[k]):
- f.write(f"{c} {u[0].item()} {u[1].item()}\n")
-
- for k in range(q_train_input[:10].size(0)):
- with open(f"example_train_{k:04d}.dat", "w") as f:
- for u, c in zip(train_input[k], train_targets[k]):
- f.write(f"{c} {u[0].item()} {u[1].item()}\n")
-
- hidden_dim = 32
- w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
- b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
- w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
- hidden_dim
- )
- b2 = torch.zeros(batch_nb_mlps, 2, device=device)
-
- w1.requires_grad_()
- b1.requires_grad_()
- w2.requires_grad_()
- b2.requires_grad_()
- optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)
-
- criterion = nn.CrossEntropyLoss()
- criterion.to(device)
-
- for k in range(nb_epochs):
- acc_train_loss = 0.0
- nb_train_errors = 0
-
- for input, targets in zip(
- train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
- ):
- h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
- h = F.relu(h)
- output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
- loss = F.cross_entropy(
- output.reshape(-1, output.size(-1)), targets.reshape(-1)
- )
- acc_train_loss += loss.item() * input.size(0)
-
- wta = output.argmax(-1)
- nb_train_errors += (wta != targets).long().sum(-1)
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- with torch.no_grad():
- for p in [w1, b1, w2, b2]:
- m = (
- torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
- ).long()
- pq = quantize(p, -2, 2)
- p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)
-
- train_error = nb_train_errors / train_input.size(1)
- acc_train_loss = acc_train_loss / train_input.size(1)
-
- # print(f"{k=} {acc_train_loss=} {train_error=}")
-
- acc_test_loss = 0
- nb_test_errors = 0
-
- for input, targets in zip(
- test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
- ):
- h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
- h = F.relu(h)
- output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
- loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
- acc_test_loss += loss.item() * input.size(0)
-
- wta = output.argmax(-1)
- nb_test_errors += (wta != targets).long().sum(-1)
-
- test_error = nb_test_errors / test_input.size(1)
- q_params = torch.cat(
- [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
- )
- q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
- batch_nb_mlps, -1
- )
- q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
- batch_nb_mlps, -1
- )
-
- return q_train_set, q_test_set, q_params, test_error
-
-
-######################################################################
-
-
-def evaluate_q_params(
- q_params,
- q_set,
- batch_size=25,
- device=torch.device("cpu"),
- nb_mlps_per_batch=1024,
- save_as_examples=False,
-):
- errors = []
- nb_mlps = q_params.size(0)
-
- for n in range(0, nb_mlps, nb_mlps_per_batch):
- batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
- batch_q_params = q_params[n : n + batch_nb_mlps]
- batch_q_set = q_set[n : n + batch_nb_mlps]
- hidden_dim = 32
- w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
- b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
- w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
- b2 = torch.empty(batch_nb_mlps, 2, device=device)
-
- with torch.no_grad():
- k = 0
- for p in [w1, b1, w2, b2]:
- print(f"{p.size()=}")
- x = dequantize(
- batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
- ).view(p.size())
- p.copy_(x)
- k += p.numel() // batch_nb_mlps
-
- batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
- data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
- data_targets = batch_q_set[:, :, 2].to(device)
-
- print(f"{data_input.size()=} {data_targets.size()=}")
-
- criterion = nn.CrossEntropyLoss()
- criterion.to(device)
-
- acc_loss = 0.0
- nb_errors = 0
-
- for input, targets in zip(
- data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
- ):
- h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
- h = F.relu(h)
- output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
- loss = F.cross_entropy(
- output.reshape(-1, output.size(-1)), targets.reshape(-1)
- )
- acc_loss += loss.item() * input.size(0)
- wta = output.argmax(-1)
- nb_errors += (wta != targets).long().sum(-1)
-
- errors.append(nb_errors / data_input.size(1))
- acc_loss = acc_loss / data_input.size(1)
-
- return torch.cat(errors)
-
-
-######################################################################
-
-
-def generate_sequence_and_test_set(
- nb_mlps,
- nb_samples,
- batch_size,
- nb_epochs,
- device,
- nb_mlps_per_batch=1024,
-):
- seqs, q_test_sets, test_errors = [], [], []
-
- for n in range(0, nb_mlps, nb_mlps_per_batch):
- q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
- batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
- nb_samples=nb_samples,
- batch_size=batch_size,
- nb_epochs=nb_epochs,
- device=device,
- )
-
- seqs.append(
- torch.cat(
- [
- q_train_set,
- q_train_set.new_full(
- (
- q_train_set.size(0),
- 1,
- ),
- nb_quantization_levels,
- ),
- q_params,
- ],
- dim=-1,
- )
- )
-
- q_test_sets.append(q_test_set)
- test_errors.append(test_error)
-
- seq = torch.cat(seqs)
- q_test_set = torch.cat(q_test_sets)
- test_error = torch.cat(test_errors)
-
- return seq, q_test_set, test_error
-
-
-######################################################################
-
-if __name__ == "__main__":
- import time
-
- batch_nb_mlps, nb_samples = 128, 250
-
- generate_sets_and_params(
- batch_nb_mlps=10,
- nb_samples=nb_samples,
- batch_size=25,
- nb_epochs=100,
- device=torch.device("cpu"),
- print_log=False,
- save_as_examples=True,
- )
-
- exit(0)
-
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- start_time = time.perf_counter()
-
- data = []
-
- seq, q_test_set, test_error = generate_sequence_and_test_set(
- nb_mlps=batch_nb_mlps,
- nb_samples=nb_samples,
- device=device,
- batch_size=25,
- nb_epochs=250,
- nb_mlps_per_batch=17,
- )
-
- end_time = time.perf_counter()
- print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")
-
- q_train_set = seq[:, : nb_samples * 3]
- q_params = seq[:, nb_samples * 3 + 1 :]
- print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
- error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
- print(f"train {error_train*100}%")
- error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
- print(f"test {error_test*100}%")
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-######################################################################
-
-
-def rpl_exec(program, stack):
- stack = stack.copy()
- for op in program:
- if op == "add":
- if len(stack) > 1:
- a, b = stack.pop(), stack.pop()
- stack.append(a + b)
- elif op == "min":
- if len(stack) > 1:
- a, b = stack.pop(), stack.pop()
- stack.append(min(a, b))
- elif op == "max":
- if len(stack) > 1:
- a, b = stack.pop(), stack.pop()
- stack.append(max(a, b))
- elif op == "swp":
- if len(stack) > 1:
- a, b = stack.pop(), stack.pop()
- stack.append(a)
- stack.append(b)
- elif op == "rep":
- if len(stack) > 1:
- a, b = stack.pop(), stack.pop()
- stack += [b] * a
- elif op == "dup":
- if len(stack) > 0:
- a = stack.pop()
- stack.append(a)
- stack.append(a)
- elif op == "del":
- if len(stack) > 0:
- a = stack.pop()
- else:
- raise ValueError(f"Unknown instruction {op}")
-
- return stack
-
-
-rpl_ops = ["add", "min", "max", "swp", "rep", "dup", "del"]
-
-######################################################################
-
-
-def generate(
- nb_starting_values=3, nb_result_values_max=None, max_input=9, prog_len=6, nb_runs=5
-):
- prog_len = (1 + torch.randint(2 * prog_len, (1,))).clamp(max=prog_len).item()
-
- while True:
- no_empty_stack = True
- prog = [rpl_ops[k] for k in torch.randint(len(rpl_ops), (prog_len,))]
-
- result = []
- for _ in range(nb_runs):
- stack = [
- x.item() for x in torch.randint(max_input + 1, (nb_starting_values,))
- ]
- result_stack = rpl_exec(prog, stack)
- if len(result_stack) == 0:
- no_empty_stack = False
- result = result + ["<in>"] + stack + ["<out>"] + result_stack
-
- result = result + ["<prg>"] + prog
- result = result + ["<end>"]
-
- if no_empty_stack and (
- nb_result_values_max is None or len(result_stack) <= nb_result_values_max
- ):
- break
-
- return result
-
-
-def next_marker(seq, tokens, start=0):
- pos = None
- for t in tokens:
- try:
- i = seq.index(t, start)
- if pos is None or i < pos:
- pos = i
- except ValueError:
- pass
- return pos
-
-
-def decompose(seq):
- io = []
- k = 0
- while seq[k] == "<in>":
- o = next_marker(seq, ["<out>"], start=k + 1)
- if o is None:
- raise ValueError("Missing output markers (should be correct in the prompt)")
- e = next_marker(seq, ["<in>", "<prg>"], start=o)
- if e is None:
- raise ValueError(
- "Missing input/output markers (should be correct in the prompt)"
- )
- try:
- io.append(
- ([int(x) for x in seq[k + 1 : o]], [int(x) for x in seq[o + 1 : e]])
- )
- except ValueError:
- raise ValueError(
- "Invalid input/output value (should be correct in the prompt)"
- )
-
- k = e
-
- if seq[k] == "<prg>":
- e = next_marker(seq, ["<end>"], start=k)
- if e is None:
- prog = []
- else:
- prog = seq[k + 1 : e]
- else:
- raise ValueError("Missing <prg> (it should be in the prompt)")
-
- return prog, io
-
-
-def stack_distance(target_stack, result_stack):
- return abs(len(result_stack) - len(target_stack)) + sum(
- [0 if x == y else 1 for x, y in zip(result_stack, target_stack)]
- )
-
-
-def compute_nb_errors(seq):
- prog, io = decompose(seq)
-
- nb_total, nb_errors = 0, 0
-
- stacks = []
-
- if len(set(prog) - set(rpl_ops)) > 0:
- # Program is not valid, we count 100% error
- for start_stack, target_stack in io:
- stacks.append((start_stack, target_stack, ["N/A"], False))
- nb_total += len(target_stack)
- nb_errors += len(target_stack)
-
- else:
- # Program is valid
- for start_stack, target_stack in io:
- result_stack = rpl_exec(prog, start_stack)
- nb_total += len(target_stack)
- e = stack_distance(target_stack, result_stack)
- nb_errors += e
- stacks.append((start_stack, target_stack, result_stack, e == 0))
-
- return nb_total, nb_errors, prog, stacks
-
-
-######################################################################
-
-if __name__ == "__main__":
- seq = generate()
- print(seq)
- seq[3] = 7
- print(seq)
- print(compute_nb_errors(seq))
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch, torchvision
-import torch.nn.functional as F
-
-
-def generate_sequences(
- nb, height, width, nb_colors, length, prompt_length, device=torch.device("cpu")
-):
- worlds = torch.randint(nb_colors, (nb, height, width), device=device)
- world_prior_visits = torch.zeros(nb, height, width, device=device)
-
- # nb x 2
- snake_position = torch.cat(
- (
- torch.randint(height, (nb, 1), device=device),
- torch.randint(width, (nb, 1), device=device),
- ),
- 1,
- )
- snake_direction = torch.randint(4, (nb,), device=device)
- sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
- sequences_prior_visits = torch.zeros(
- nb, 2 * length, device=device, dtype=torch.int64
- )
- i = torch.arange(nb, device=device) # [:,None]
-
- for l in range(length):
- # nb x 3
- snake_next_direction = torch.cat(
- (
- (snake_direction[:, None] - 1) % 4,
- snake_direction[:, None],
- (snake_direction[:, None] + 1) % 4,
- ),
- 1,
- )
-
- # nb x 3
- vh = (snake_next_direction + 1) % 2 * (snake_next_direction - 1)
- vw = snake_next_direction % 2 * (snake_next_direction - 2)
-
- # nb x 3 x 2
- snake_next_speed = torch.cat((vh[:, :, None], vw[:, :, None]), 2)
- snake_next_position = snake_position[:, None, :] + snake_next_speed
-
- # nb x 3
- val = torch.logical_and(
- torch.logical_and(
- snake_next_position[:, :, 0] >= 0, snake_next_position[:, :, 0] < height
- ),
- torch.logical_and(
- snake_next_position[:, :, 1] >= 0, snake_next_position[:, :, 1] < width
- ),
- ).float()
- val = (
- # The multiplicative factors bias toward moving forward
- torch.rand_like(val)
- * val
- * torch.tensor([[1.0, 2.0, 1.0]], device=device)
- )
-
- # nb
- j = val.argmax(1)
- snake_direction = snake_next_direction[i, j]
-
- sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
- sequences_prior_visits[:, 2 * l] = world_prior_visits[
- i, snake_position[:, 0], snake_position[:, 1]
- ]
- if l < prompt_length:
- world_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
- sequences[:, 2 * l + 1] = snake_direction
-
- # nb x 2
- snake_position = snake_next_position[i, j]
-
- return sequences, sequences_prior_visits, worlds, world_prior_visits
-
-
-# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
-# exit(0)
-
-
-def solver(input, ar_mask):
- for n in range(input.size(0)):
- i, j, memory = 0, 0, {}
- # print(input[n])
- # print(ar_mask[n])
- for l in range(input.size(1) // 2):
- if ar_mask[n, 2 * l] == 1:
- if memory.get((i, j)) is None:
- input[n, 2 * l] = -1
- else:
- input[n, 2 * l] = memory[(i, j)]
- else:
- # print(f'@3 {memory=}')
- if memory.get((i, j)) is None:
- memory[(i, j)] = input[n, 2 * l]
- else:
- assert memory[(i, j)] == input[n, 2 * l], f"n={n} l={l}"
- # print(f'@1 {i=} {j=}')
- d = input[n, 2 * l + 1].item()
- i += (d + 1) % 2 * (d - 1)
- j += d % 2 * (d - 2)
- # print(f'@2 {i=} {j=}')
-
-
-def seq2str(seq):
- return "".join(["NESW123456789"[i] for i in seq])
-
-
-######################################################################
-
-if __name__ == "__main__":
- train_input, train_prior_visits, _, _ = generate_sequences(
- nb=20,
- height=9,
- width=12,
- nb_colors=5,
- length=50,
- prompt_length=100,
- )
-
- print([seq2str(s) for s in train_input])
-
-######################################################################
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import torch, torchvision
-
-######################################################################
-
-# CODE_OP=[0 for push, 1 for pop] + 2 * n_stack
-# CODE_VAL=val + 2 * nb_stacks
-
-
-def generate_sequences(
- nb, nb_steps, nb_stacks, nb_digits, values=None, device=torch.device("cpu")
-):
- stack = torch.empty(nb, nb_stacks, nb_steps, dtype=torch.int64)
- stack_counts = torch.zeros(nb, nb_stacks, dtype=torch.int64)
- k = torch.arange(nb)
- result = torch.empty(nb, (1 + nb_digits) * nb_steps, dtype=torch.int64)
- recorded_stack_counts = torch.zeros(
- nb, (1 + nb_digits) * nb_steps, dtype=torch.int64
- )
-
- for t in range(nb_steps):
- op = torch.randint(2, (nb,))
- st = torch.randint(nb_stacks, (nb,))
- op = op * (stack_counts[k, st] > 0)
- if values is None:
- val_push = torch.randint(10**nb_digits, (nb,))
- else:
- val_push = values[torch.randint(values.size(0), (nb,))]
- val_pop = stack[
- k,
- st,
- (stack_counts[k, st] - 1).clamp(min=0),
- ]
- stack[k, st, stack_counts[k, st]] = val_push
- recorded_stack_counts[:, (1 + nb_digits) * t] = stack_counts[k, st]
- stack_counts[k[op == 0], st[op == 0]] += 1
- stack_counts[k[op == 1], st[op == 1]] -= 1
- result[:, (1 + nb_digits) * t] = st * 2 + op
- for d in range(nb_digits):
- result[:, (1 + nb_digits) * t + 1 + d] = (
- (op * val_pop + (1 - op) * val_push) // (10**d)
- ) % 10 + 2 * nb_stacks
-
- return result.to(device), recorded_stack_counts.to(device)
-
-
-def remove_popped_values(seq, nb_stacks, nb_digits):
- m = torch.logical_and(seq % 2 == 1, seq < 2 * nb_stacks).long()
- for d in range(nb_digits):
- k = d + 1
- seq[:, k:] = -m[:, :-k] + (1 - m[:, :-k]) * seq[:, k:]
-
-
-def seq_to_str(seq, nb_stacks, nb_digits, recorded_stack_counts=None):
- assert seq.size(0) % (1 + nb_digits) == 0
- s = ""
- for t in range(seq.size(0) // (1 + nb_digits)):
- n_op = seq[(1 + nb_digits) * t]
- if t > 0:
- s += " "
- if recorded_stack_counts is not None:
- s += f"[{recorded_stack_counts[(1 + nb_digits)*t]}] "
- s += f"POP" if n_op % 2 == 1 else f"PSH"
- if nb_stacks > 1:
- s += f"_{n_op//2}"
- for d in range(nb_digits):
- if seq[(1 + nb_digits) * t + 1 + d] == -1:
- s += " ?"
- else:
- s += f" {seq[(1 + nb_digits) * t + 1 + d] - 2 * nb_stacks:1d}"
- return s
-
-
-######################################################################
-
-if __name__ == "__main__":
- nb, nb_steps, nb_stacks, nb_digits = 150000, 20, 2, 1
- seq, recorded_stack_counts = generate_sequences(
- nb=nb,
- nb_steps=nb_steps,
- nb_stacks=nb_stacks,
- nb_digits=nb_digits,
- )
-
- for n in range(min(10, seq.size(0))):
- print(
- seq_to_str(
- seq[n],
- nb_stacks=nb_stacks,
- nb_digits=nb_digits,
- recorded_stack_counts=recorded_stack_counts[n],
- )
- )
- # print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
-
- print("-- PREPARED FOR TEST -----------------")
-
- remove_popped_values(seq, nb_stacks, nb_digits)
-
- for n in range(min(10, seq.size(0))):
- print(seq_to_str(seq[n], nb_stacks=nb_stacks, nb_digits=nb_digits))
pass
-class TaskFromFile(Task):
- def tensorize(self, pairs, shuffle):
- len_max = max([len(x[0]) for x in pairs])
-
- input = torch.cat(
- [
- torch.tensor(
- [
- [self.char2id[c] for c in s[0] + "#" * (len_max - len(s[0]))]
- for s in pairs
- ]
- )
- ],
- 0,
- ).to("cpu")
-
- pred_mask = torch.cat(
- [
- torch.tensor(
- [
- [int(c) for c in s[1] + "0" * (len_max - len(s[1]))]
- for s in pairs
- ]
- )
- ],
- 0,
- ).to("cpu")
-
- if shuffle:
- i = torch.randperm(input.size(0))
- input = input[i].contiguous()
- pred_mask = pred_mask[i].contiguous()
-
- return input, pred_mask
-
- # trim all the tensors in the tuple z to remove as much token from
- # left and right in the first tensor. If z is a tuple, all its
- # elements are trimed according to the triming for the first
- def trim(self, z, token="#"):
- n = self.char2id[token]
- if type(z) == tuple:
- x = z[0]
- i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
- a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
- return tuple([t[:, a:b] for t in z])
- else:
- i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
- a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
- return z[:, a:b]
-
- def __init__(
- self,
- train_filename,
- test_filename,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- shuffle=False,
- device=torch.device("cpu"),
- ):
- self.batch_size = batch_size
- self.device = device
-
- def read_file(filename, nb=-1):
- pairs = []
- with open(filename, "r") as f:
- while True:
- sequence = f.readline().strip()
- if not sequence:
- break
- pred_mask = f.readline().strip()
- assert len(sequence) == len(pred_mask)
- assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
- pairs.append((sequence, pred_mask))
- if len(pairs) == nb:
- break
-
- if nb > 0:
- pairs = pairs[:nb]
- assert len(pairs) == nb
-
- return pairs
-
- train_pairs = read_file(train_filename, nb_train_samples)
- test_pairs = read_file(test_filename, nb_test_samples)
-
- symbols = ["#"] + list(
- set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
- )
- self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
- self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
- self.train_input, self.train_pred_masks = self.tensorize(
- train_pairs, shuffle=shuffle
- )
- self.test_input, self.test_pred_masks = self.tensorize(
- test_pairs, shuffle=shuffle
- )
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield self.trim(batch).to(self.device)
-
- def vocabulary_size(self):
- return len(self.char2id)
-
- def tensor2str(self, t):
- return ["".join([self.id2char[x.item()] for x in s]) for s in t]
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- correct = self.trim(self.test_input[:1000]).to(self.device)
- result = correct.clone()
- pred_mask = self.test_pred_masks[:1000, : result.size(1)].to(self.device)
- ar_mask = (pred_mask > 0).long()
- result *= 1 - ar_mask # paraaaaanoiaaaaaaa
-
- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:50]):
- logger(f"test_before {e}")
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- logger(f"----------------------------------------------------------")
-
- for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
- logger(f"test_after {e}")
- logger(f"correct {c}")
-
- logger(f"----------------------------------------------------------")
-
- err_mask = (pred_mask == 2).long()
- nb_total = err_mask.sum().item()
- nb_correct = ((correct == result).long() * err_mask).sum().item()
-
- logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
- logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
-
-
-####################
-
-import problems
-
-
-class SandBox(Task):
- def __init__(
- self,
- problem,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- logger=None,
- device=torch.device("cpu"),
- max_nb_codes=1024,
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.device = device
- self.problem = problem
-
- self.train_input, self.train_ar_mask = self.problem.generate_sequences(
- nb_train_samples
- )
- self.test_input, self.test_ar_mask = self.problem.generate_sequences(
- nb_test_samples
- )
-
- self.train_input, self.train_ar_mask = self.train_input.to(
- device
- ), self.train_ar_mask.to(device)
- self.test_input, self.test_ar_mask = self.test_input.to(
- device
- ), self.test_ar_mask.to(device)
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- # A bit of paranoia never hurts
- assert self.nb_codes <= max_nb_codes
- assert self.train_input.min() >= 0
- assert self.test_input.min() >= 0
- assert tuple(x.item() for x in self.train_ar_mask.unique()) in {
- (0,),
- (1,),
- (0, 1),
- }
- assert tuple(x.item() for x in self.test_ar_mask.unique()) in {
- (0,),
- (1,),
- (0, 1),
- }
-
- if logger is not None:
- for s, a in zip(self.train_input[:100], self.train_ar_mask[:100]):
- logger(f"train_sequences {self.problem.seq2str(s)}")
- a = "".join(["01"[x.item()] for x in a])
- logger(f" {a}")
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
- ):
- def compute_accuracy(input, ar_mask, logger=None):
- input, ar_mask = input[:nmax], ar_mask[:nmax]
- result = input.clone() * (1 - ar_mask)
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-
- log_ground_truth = ar_mask.min() == 0
-
- if logger is not None:
- for sp, st in zip(result[:10], input[:10]):
- logger(
- f"test_sequences {n_epoch} prediction {self.problem.seq2str(sp)}"
- )
- if log_ground_truth:
- logger(
- f" {n_epoch} ground truth {self.problem.seq2str(st)}"
- )
-
- nb_total, nb_correct = self.problem.compute_nb_correct(
- input, ar_mask, result
- )
-
- # nb_total = ar_mask.sum().item()
- # nb_correct = ((result == input).long() * ar_mask).sum().item()
-
- return nb_total, nb_correct
-
- train_nb_total, train_nb_correct = compute_accuracy(
- self.train_input, self.train_ar_mask
- )
-
- logger(
- f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
- )
-
- test_nb_total, test_nb_correct = compute_accuracy(
- self.test_input, self.test_ar_mask, logger
- )
-
- logger(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
-
- logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
- if save_attention_image is not None:
- for k in range(10):
- ns = torch.randint(self.test_input.size(0), (1,)).item()
- input = self.test_input[ns : ns + 1].clone()
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
- # model.record_attention(True)
- model(BracketedSequence(input))
- model.train(t)
- # ram = model.retrieve_attention()
- # model.record_attention(False)
-
- # tokens_output = [c for c in self.problem.seq2str(input[0])]
- # tokens_input = ["n/a"] + tokens_output[:-1]
- # for n_head in range(ram[0].size(1)):
- # filename = os.path.join(
- # result_dir, f"sandbox_attention_{k}_h{n_head}.pdf"
- # )
- # attention_matrices = [m[0, n_head] for m in ram]
- # save_attention_image(
- # filename,
- # tokens_input,
- # tokens_output,
- # attention_matrices,
- # k_top=10,
- ##min_total_attention=0.9,
- # token_gap=12,
- # layer_gap=50,
- # )
- # logger(f"wrote {filename}")
-
-
-######################################################################
-
-import picoclvr
-
-
-class PicoCLVR(Task):
- # Make a tensor from a list of strings
- def tensorize(self, descr):
- token_descr = [s.strip().split(" ") for s in descr]
- l = max([len(s) for s in token_descr])
- token_descr = [s + ["<nul>"] * (l - len(s)) for s in token_descr]
- id_descr = [[self.token2id[u] for u in s] for s in token_descr]
- return torch.tensor(id_descr, device=self.device)
-
- # Make a list of strings from a tensor
- def detensorize(self, x):
- return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
-
- # trim all the tensors in the tuple z to remove as much token from
- # left and right in the first tensor. If z is a tuple, all its
- # elements are trimed according to the triming for the first
- def trim(self, z, token="<nul>"):
- n = self.token2id[token]
- if type(z) == tuple:
- x = z[0]
- i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
- a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
- return tuple([t[:, a:b] for t in z])
- else:
- i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
- a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
- return z[:, a:b]
-
- ######################
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- height,
- width,
- nb_colors=5,
- logger=None,
- device=torch.device("cpu"),
- pruner_train=None,
- pruner_eval=None,
- ):
- super().__init__()
-
- def generate_descr(nb, cache_suffix, pruner):
- return picoclvr.generate(
- nb,
- height=self.height,
- width=self.width,
- nb_colors=nb_colors,
- pruner=pruner,
- )
-
- self.height = height
- self.width = width
- self.batch_size = batch_size
- self.device = device
- self.pruner_train = pruner_train
- self.pruner_eval = pruner_eval
-
- if logger is not None:
- logger(
- f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
- )
-
- self.train_descr = generate_descr(
- nb_train_samples, "train", pruner=self.pruner_train
- )
- self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)
-
- # Build the tokenizer
- tokens = {"<nul>", "<img>"}
- for d in [self.train_descr, self.test_descr]:
- for s in d:
- for t in s.strip().split(" "):
- tokens.add(t)
- # make this set a sorted list to get the same tensors given
- # the same descr
- tokens = list(tokens)
- tokens.sort()
- self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
- self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
- self.t_img, self.t_nul = self.token2id["<img>"], self.token2id["<nul>"]
-
- # Tokenize the train and test sets
- self.train_input = self.tensorize(self.train_descr)
- self.test_input = self.tensorize(self.test_descr)
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
- ):
- yield self.trim(batch)
-
- def vocabulary_size(self):
- return len(self.token2id)
-
- def compute_missing_properties(
- self, n_epoch, model, logger, deterministic_synthesis, pruner=None
- ):
- acc_nb_requested_properties = []
- acc_nb_missing_properties = []
- acc_nb_results = 0
-
- for input in tqdm.tqdm(
- self.test_input.split(self.batch_size),
- dynamic_ncols=True,
- desc=f"test-properties",
- ):
- result = input.clone()
- ar_mask = (result == self.t_img).long().cumsum(dim=1).clamp(max=1)
- result = (1 - ar_mask) * result + ar_mask * self.t_nul
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
-
- result_descr = self.detensorize(result)
- np = picoclvr.nb_properties(
- result_descr,
- height=self.height,
- width=self.width,
- pruner=pruner,
- )
- nb_requested_properties, _, nb_missing_properties = zip(*np)
- acc_nb_requested_properties += nb_requested_properties
- acc_nb_missing_properties += nb_missing_properties
- acc_nb_results += len(result_descr)
-
- nb_requested_properties = sum(acc_nb_requested_properties)
- nb_missing_properties = sum(acc_nb_missing_properties)
-
- prefix = "" if pruner is None else "pruned_"
- logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
- logger(
- f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
- )
- logger(
- f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
- )
-
- logger(
- f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
- )
-
- ######################################################################
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)
-
- if self.pruner_eval is not None:
- self.compute_missing_properties(n_epoch, model, self.pruner_eval)
-
- nb_tokens_to_generate = self.height * self.width + 3
- result_descr = []
- nb_per_primer = 8
- primer = []
-
- for primer_descr in [
- "red above green <sep> green top <sep> blue right of red",
- "there is red <sep> there is yellow <sep> there is blue",
- "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
- "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
- ]:
- primer += [primer_descr + " <img>"] * nb_per_primer
-
- result = self.tensorize(primer)
- fill = result.new_full(
- result.size()[:-1] + (self.height * self.width + 1,), self.t_nul
- )
- result = torch.cat((result, fill), 1)
- ar_mask = (result == self.t_nul).long()
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
- result_descr = self.detensorize(result)
-
- np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)
-
- acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
- acc_nb_results = len(result_descr)
-
- nb_requested_properties = sum(acc_nb_requested_properties)
- nb_missing_properties = sum(acc_nb_missing_properties)
-
- prefix = "demo_"
- logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
- logger(
- f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
- )
- logger(
- f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
- )
-
- img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)
-
- if img.dim() == 5:
- if img.size(1) == 1:
- img = F.pad(img.squeeze(1), pad=(1, 1, 1, 1), value=64)
- else:
- img = torch.cat(
- [
- torchvision.utils.make_grid(x, padding=1, pad_value=64)[None]
- for x in img
- ],
- 0,
- )
-
- image_name = os.path.join(result_dir, f"picoclvr_result_{n_epoch:04d}.png")
- torchvision.utils.save_image(
- img / 255.0, image_name, nrow=nb_per_primer, padding=1, pad_value=0.0
- )
- logger(f"wrote {image_name}")
-
-
-######################################################################
-
-
-class MNIST(Task):
- def __init__(
- self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
- ):
- super().__init__()
-
- self.nb_train_samples = (nb_train_samples,)
- self.nb_test_samples = (nb_test_samples,)
- self.batch_size = batch_size
- self.device = device
- data_set = torchvision.datasets.MNIST(root="./data", train=True, download=True)
- self.train_input = data_set.data[:nb_train_samples].view(-1, 28 * 28).long()
- data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
- self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return 256
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
- ar_mask = torch.full_like(results, 1)
- masked_inplace_autoregression(
- model,
- self.batch_size,
- results,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
- image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
- torchvision.utils.save_image(
- 1 - results.reshape(-1, 1, 28, 28) / 255.0,
- image_name,
- nrow=16,
- pad_value=0.8,
- )
- logger(f"wrote {image_name}")
-
-
-######################################################################
-
-import maze
-
-
-class Maze(Task):
- def map2seq(self, *m):
- return torch.cat([x.flatten(1) for x in m], 1)
-
- def seq2map(self, s):
- s = s.reshape(s.size(0), -1, self.height, self.width)
- return (s[:, k] for k in range(s.size(1)))
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- height,
- width,
- nb_walls,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.height = height
- self.width = width
- self.device = device
-
- train_mazes, train_paths, _ = maze.create_maze_data(
- nb_train_samples,
- height=height,
- width=width,
- nb_walls=nb_walls,
- progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-train"),
- )
- self.train_input = self.map2seq(train_mazes.to(device), train_paths.to(device))
-
- test_mazes, test_paths, _ = maze.create_maze_data(
- nb_test_samples,
- height=height,
- width=width,
- nb_walls=nb_walls,
- progress_bar=lambda x: tqdm.tqdm(x, dynamic_ncols=True, desc=f"data-test"),
- )
- self.test_input = self.map2seq(test_mazes.to(device), test_paths.to(device))
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def compute_error(
- self, model, split="train", nb_to_use=-1, deterministic_synthesis=False
- ):
- model_device = next(model.parameters()).device
- nb_total, nb_correct = 0, 0
- count = torch.zeros(
- self.width * self.height,
- self.width * self.height,
- device=model_device,
- dtype=torch.int64,
- )
-
- for input in self.batches(split, nb_to_use):
- input = input.to(model_device)
- result = input.clone()
- ar_mask = result.new_zeros(result.size())
- ar_mask[:, self.height * self.width :] = 1
- result *= 1 - ar_mask
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- progress_bar_desc=None,
- device=self.device,
- )
- mazes, paths = self.seq2map(result)
- path_correctness = maze.path_correctness(mazes, paths)
- nb_correct += path_correctness.long().sum()
- nb_total += mazes.size(0)
-
- optimal_path_lengths = (
- (input[:, self.height * self.width :] == maze.v_path).long().sum(1)
- )
- predicted_path_lengths = (
- (result[:, self.height * self.width :] == maze.v_path).long().sum(1)
- )
- optimal_path_lengths = optimal_path_lengths[path_correctness]
- predicted_path_lengths = predicted_path_lengths[path_correctness]
- count[optimal_path_lengths, predicted_path_lengths] += 1
-
- if count.max() == 0:
- count = None
- else:
- count = count[
- : count.sum(1).nonzero().max() + 1, : count.sum(0).nonzero().max() + 1
- ]
-
- return nb_total, nb_correct, count
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- train_nb_total, train_nb_correct, count = self.compute_error(
- model,
- "train",
- nb_to_use=1000,
- deterministic_synthesis=deterministic_synthesis,
- )
- logger(
- f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
- )
-
- test_nb_total, test_nb_correct, count = self.compute_error(
- model,
- "test",
- nb_to_use=1000,
- deterministic_synthesis=deterministic_synthesis,
- )
- logger(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
-
- logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
- if count is not None:
- proportion_optimal = count.diagonal().sum().float() / count.sum()
- logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
- with open(
- os.path.join(result_dir, f"maze_result_{n_epoch:04d}.txt"), "w"
- ) as f:
- for i in range(count.size(0)):
- for j in range(count.size(1)):
- eol = " " if j < count.size(1) - 1 else "\n"
- f.write(f"{count[i,j]}{eol}")
-
- input = self.test_input[:48].to(next(model.parameters()).device)
- result = input.clone()
- ar_mask = result.new_zeros(result.size())
- ar_mask[:, self.height * self.width :] = 1
- result *= 1 - ar_mask
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- mazes, paths = self.seq2map(input)
- _, predicted_paths = self.seq2map(result)
-
- filename = os.path.join(result_dir, f"maze_result_{n_epoch:04d}.png")
- maze.save_image(
- filename,
- mazes=mazes,
- target_paths=paths,
- predicted_paths=predicted_paths,
- path_correct=maze.path_correctness(mazes, predicted_paths),
- path_optimal=maze.path_optimality(paths, predicted_paths),
- )
- logger(f"wrote {filename}")
-
-
-######################################################################
-
-
-import snake
-
-
-class Snake(Task):
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- height,
- width,
- nb_colors,
- length,
- prompt_length,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.height = height
- self.width = width
- self.device = device
- self.prompt_length = prompt_length
-
- self.train_input, self.train_prior_visits, _, _ = snake.generate_sequences(
- nb_train_samples,
- height,
- width,
- nb_colors,
- length,
- prompt_length,
- self.device,
- )
- self.test_input, self.test_prior_visits, _, _ = snake.generate_sequences(
- nb_test_samples,
- height,
- width,
- nb_colors,
- length,
- prompt_length,
- self.device,
- )
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- def compute_nb_correct(input, prior_visits):
- result = input.clone()
- i = torch.arange(result.size(1), device=result.device)[None, :]
- ar_mask = (
- torch.logical_and(i >= self.prompt_length * 2, i % 2 == 0)
- .long()
- .expand_as(result)
- )
- result *= 1 - ar_mask
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- nb_total = ((prior_visits > 0) * ar_mask).sum()
-
- nb_correct = ((result == input).long() * (prior_visits > 0) * ar_mask).sum()
-
- return nb_total, nb_correct
-
- test_nb_total, test_nb_correct = compute_nb_correct(
- self.test_input[:1000], self.test_prior_visits[:1000]
- )
-
- logger(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
-
- logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
-
-######################################################################
-
-
-import stack
-
-
-class Stack(Task):
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- logger,
- nb_steps,
- nb_stacks,
- nb_digits,
- fraction_values_for_train=None,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.nb_steps = nb_steps
- self.nb_stacks = nb_stacks
- self.nb_digits = nb_digits
- self.device = device
-
- if fraction_values_for_train is None:
- values_for_train = None
- values_for_test = None
- else:
- all = torch.randperm(10**nb_digits)
- nb_for_train = int(all.size(0) * fraction_values_for_train)
- values_for_train = all[:nb_for_train]
- values_for_test = all[nb_for_train:]
-
- self.train_input, self.train_stack_counts = stack.generate_sequences(
- nb_train_samples,
- nb_steps,
- nb_stacks,
- nb_digits,
- values_for_train,
- self.device,
- )
-
- self.test_input, self.test_stack_counts = stack.generate_sequences(
- nb_test_samples,
- nb_steps,
- nb_stacks,
- nb_digits,
- values_for_test,
- self.device,
- )
-
- i = torch.logical_and(self.test_input % 2 == 1, self.test_input < 2 * nb_stacks)
- counts = self.test_stack_counts.flatten()[i.flatten()]
- counts = F.one_hot(counts).sum(0)
- logger(f"test_pop_stack_counts {counts}")
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- def compute_nb_correct(input):
- result = input.clone()
- stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
- ar_mask = (result != input).long()
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- errors = ((result != input).long() * ar_mask).reshape(
- -1, 1 + self.nb_digits
- )
- ar_mask = ar_mask.reshape(-1, 1 + self.nb_digits)
-
- nb_total = ar_mask.max(1).values.sum()
- nb_correct = nb_total - errors.max(1).values.sum()
-
- return nb_total, nb_correct
-
- test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
-
- logger(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
-
- logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
- ##############################################################
- # Log a few generated sequences
- input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
- result = input.clone()
- stack.remove_popped_values(result, self.nb_stacks, self.nb_digits)
- ar_mask = (result != input).long()
-
- # for n in range(result.size(0)):
- # logger(
- # f"test_before {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
- # )
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
- for label, input in [
- ("train", self.train_input[:32]),
- ("test", self.test_input[:32]),
- ]:
- output = model(BracketedSequence(input)).x
- output = output.log_softmax(dim=-1)
- filename = os.path.join(
- result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
- )
- with open(filename, "w") as f:
- for n in range(input.size(0)):
- s = stack.seq_to_str(
- input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
- )
- for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
- u = (
- " " * (10 - len(w))
- + w
- + " "
- + str(output[n][t][k].exp().item())
- + "\n"
- )
- f.write(u)
- f.write("\n")
- logger(f"wrote {filename}")
- #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
- for n in range(result.size(0)):
- logger(
- f"test_after {stack.seq_to_str(result[n],nb_stacks=self.nb_stacks,nb_digits=self.nb_digits)}"
- )
- ##############################################################
-
-
-######################################################################
-
-import rpl
-
-
-class RPL(Task):
- def tensorize(self, sequences):
- len_max = max([len(x) for x in sequences])
- return torch.cat(
- [
- torch.tensor(
- [
- [
- self.token2id[str(c)]
- for c in s + ["<nul>"] * (len_max - len(s))
- ]
- for s in sequences
- ]
- )
- ],
- 0,
- )
-
- def seq2str(self, seq):
- return " ".join([self.id2token[i] for i in seq])
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- nb_starting_values=3,
- max_input=9,
- prog_len=6,
- nb_runs=5,
- no_prog=False,
- logger=None,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.device = device
- self.no_prog = no_prog
-
- train_sequences = [
- rpl.generate(
- nb_starting_values=nb_starting_values,
- nb_result_values_max=4 * nb_starting_values,
- max_input=max_input,
- prog_len=prog_len,
- nb_runs=nb_runs,
- )
- for _ in tqdm.tqdm(range(nb_train_samples), desc="train-data")
- ]
-
- test_sequences = [
- rpl.generate(
- nb_starting_values=nb_starting_values,
- nb_result_values_max=4 * nb_starting_values,
- max_input=max_input,
- prog_len=prog_len,
- nb_runs=nb_runs,
- )
- for _ in tqdm.tqdm(range(nb_test_samples), desc="test-data")
- ]
-
- symbols = list(
- set(["<nul>"] + [x for l in train_sequences + test_sequences for x in l])
- )
- val_max = max([x if type(x) is int else 0 for x in symbols])
- symbols = list(filter(lambda x: type(x) is str, symbols))
- symbols.sort()
- symbols += [str(n) for n in range(val_max + 1)]
- self.token2id = dict([(c, n) for n, c in enumerate(symbols)])
- self.id2token = dict([(n, c) for c, n in self.token2id.items()])
-
- self.t_nul = self.token2id["<nul>"]
- self.t_input = self.token2id["<in>"]
- self.t_output = self.token2id["<out>"]
- self.t_prog = self.token2id["<prg>"]
- self.t_end = self.token2id["<end>"]
-
- self.train_input = self.tensorize(train_sequences)
- self.test_input = self.tensorize(test_sequences)
-
- if no_prog:
- # Excise the program from every train and test example
- k = torch.arange(self.train_input.size(1), device=self.train_input.device)[
- None, :
- ]
- p = (
- ((self.train_input == self.t_prog).long() * k)
- .max(1, keepdim=True)
- .values
- )
- self.train_input = (
- self.train_input * (k <= p).long()
- + self.t_end * (k == p + 1).long()
- + self.t_nul * (k > p + 1).long()
- )
- k = torch.arange(self.test_input.size(1), device=self.test_input.device)[
- None, :
- ]
- p = (
- ((self.test_input == self.t_prog).long() * k)
- .max(1, keepdim=True)
- .values
- )
- self.test_input = (
- self.test_input * (k <= p).long()
- + self.t_end * (k == p + 1).long()
- + self.t_nul * (k > p + 1).long()
- )
-
- if logger is not None:
- logger(f"value_max {val_max}")
- for x in self.train_input[:25]:
- end = (x != self.t_nul).nonzero().max().item() + 1
- seq = [self.id2token[i.item()] for i in x[:end]]
- s = " ".join(seq)
- logger(f"example_seq {s}")
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- last = (batch != self.t_nul).max(0).values.nonzero().max() + 3
- batch = batch[:, :last].to(self.device)
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- # --------------------------------------------------------------------
- def compute_nb_errors_prog(input, nb_to_log=0):
- result = input.clone()
- s = (result == self.t_prog).long()
- ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
- result = (1 - ar_mask) * result + ar_mask * self.t_nul
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- sum_nb_total, sum_nb_errors = 0, 0
- for one_input, one_result in zip(input, result):
- seq = [self.id2token[i.item()] for i in one_result]
- nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq)
- sum_nb_total += 1
- sum_nb_errors += 0 if nb_errors == 0 else 1
- if nb_to_log > 0:
- gt_seq = [self.id2token[i.item()] for i in one_input]
- _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq)
- gt_prog = " ".join([str(x) for x in gt_prog])
- prog = " ".join([str(x) for x in prog])
- comment = "*" if nb_errors == 0 else "-"
- logger(f"{comment} PROG [{gt_prog}] PREDICTED [{prog}]")
- for start_stack, target_stack, result_stack, correct in stacks:
- comment = "*" if correct else "-"
- start_stack = " ".join([str(x) for x in start_stack])
- target_stack = " ".join([str(x) for x in target_stack])
- result_stack = " ".join([str(x) for x in result_stack])
- logger(
- f" {comment} [{start_stack}] -> [{target_stack}] PREDICTED [{result_stack}]"
- )
- nb_to_log -= 1
-
- return sum_nb_total, sum_nb_errors
-
- # --------------------------------------------------------------------
- def compute_nb_errors_output(input, nb_to_log=0):
- result = input.clone()
- k = torch.arange(result.size(1), device=result.device)[None, :]
- last_output_idx = (
- ((result == self.t_output) * k).max(dim=1, keepdim=True).values
- )
- first_prog_idx = (
- ((result == self.t_prog) * k).max(dim=1, keepdim=True).values
- )
- ar_mask = (k > last_output_idx).long() * (k < first_prog_idx).long()
- result = (1 - ar_mask) * result + ar_mask * self.t_nul
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- sum_nb_total, sum_nb_errors = 0, 0
- for one_input, one_result, i, j in zip(
- input, result, last_output_idx, first_prog_idx
- ):
- seq = [self.id2token[i.item()] for i in one_result]
- sum_nb_total += 1
- correct = (one_input - one_result).abs().max() == 0
- sum_nb_errors += 0 if correct else 1
- if nb_to_log > 0:
- result_stack = [
- self.id2token[i.item()] for i in one_result[i : j + 1]
- ]
- target_stack = [
- self.id2token[i.item()] for i in one_input[i : j + 1]
- ]
- comment = "*" if correct else "-"
- result_stack = " ".join([str(x) for x in result_stack])
- target_stack = " ".join([str(x) for x in target_stack])
- logger(
- f"output_test {comment} [{target_stack}] PREDICTED [{result_stack}]"
- )
- nb_to_log -= 1
-
- return sum_nb_total, sum_nb_errors
-
- # --------------------------------------------------------------------
-
- if not self.no_prog:
- test_nb_total, test_nb_errors = compute_nb_errors_prog(
- self.test_input[:1000].to(self.device), nb_to_log=10
- )
-
- logger(
- f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
- )
-
- logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
-
- test_nb_total, test_nb_errors = compute_nb_errors_output(
- self.test_input[:1000].to(self.device), nb_to_log=10
- )
-
- logger(
- f"accuracy_output_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
- )
-
- if save_attention_image is None:
- logger("no save_attention_image (is pycairo installed?)")
- else:
- ns = torch.randint(self.test_input.size(0), (1,)).item()
- input = self.test_input[ns : ns + 1].clone()
- last = (input != self.t_nul).max(0).values.nonzero().max() + 3
- input = input[:, :last].to(self.device)
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
- model.record_attention(True)
- model(BracketedSequence(input))
- model.train(t)
- ram = model.retrieve_attention()
- model.record_attention(False)
-
- tokens_output = [self.id2token[i.item()] for i in input[0]]
- tokens_input = ["n/a"] + tokens_output[:-1]
- for n_head in range(ram[0].size(1)):
- filename = os.path.join(
- result_dir, f"rpl_attention_{n_epoch}_h{n_head}.pdf"
- )
- attention_matrices = [m[0, n_head] for m in ram]
- save_attention_image(
- filename,
- tokens_input,
- tokens_output,
- attention_matrices,
- k_top=10,
- # min_total_attention=0.9,
- token_gap=12,
- layer_gap=50,
- )
- logger(f"wrote {filename}")
-
-
-######################################################################
-
-
-import expr
-
-
-class Expr(Task):
- def tensorize(self, sequences):
- len_max = max([len(x) for x in sequences])
- return torch.cat(
- [
- torch.tensor(
- [
- [self.char2id[c] for c in s + "#" * (len_max - len(s))]
- for s in sequences
- ]
- )
- ],
- 0,
- ).to(self.device)
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- nb_variables,
- sequence_length,
- operand_max,
- result_max,
- batch_size,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.device = device
-
- train_sequences = expr.generate_sequences(
- nb_train_samples,
- nb_variables=nb_variables,
- length=sequence_length,
- operand_max=operand_max,
- result_max=result_max,
- )
-
- test_sequences = expr.generate_sequences(
- nb_test_samples,
- nb_variables=nb_variables,
- length=sequence_length,
- operand_max=operand_max,
- result_max=result_max,
- )
-
- symbols = list(set("#" + "".join(train_sequences + test_sequences)))
- symbols.sort()
-
- self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
- self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
- self.filler, self.space = self.char2id["#"], self.char2id[" "]
-
- self.train_input = self.tensorize(train_sequences)
- self.test_input = self.tensorize(test_sequences)
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- last = (batch != self.filler).max(0).values.nonzero().max() + 3
- batch = batch[:, :last]
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def seq2str(self, s):
- return "".join([self.id2char[k.item()] for k in s])
-
- def produce_results(
- self,
- n_epoch,
- model,
- result_dir,
- logger,
- deterministic_synthesis,
- input_file=None,
- ):
- def compute_nb_correct(input):
- result = input.clone()
- s = (result == self.space).long()
- ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
- result = (1 - ar_mask) * result + ar_mask * self.filler
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- nb_total = input.size(0)
- nb_correct = (input == result).long().min(1).values.sum()
-
- #######################################################################
- # Comput predicted vs. true variable values
-
- nb_delta = torch.zeros(5, dtype=torch.int64)
- nb_missed = 0
-
- values_input = expr.extract_results([self.seq2str(s) for s in input])
- values_result = expr.extract_results([self.seq2str(s) for s in result])
-
- filename = os.path.join(result_dir, f"expr_result_{n_epoch:04d}.txt")
-
- with open(filename, "w") as f:
- for i, r in zip(values_input, values_result):
- for n, vi in i.items():
- vr = r.get(n)
- f.write(f"{vi} {-1 if vr is None else vr}\n")
-
- if vr is None or vr < 0:
- nb_missed += 1
- else:
- d = abs(vr - vi)
- if d >= nb_delta.size(0):
- nb_missed += 1
- else:
- nb_delta[d] += 1
-
- ######################################################################
-
- return nb_total, nb_correct, nb_delta, nb_missed
-
- (
- test_nb_total,
- test_nb_correct,
- test_nb_delta,
- test_nb_missed,
- ) = compute_nb_correct(self.test_input[:10000])
-
- logger(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
-
- logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
-
- nb_total = test_nb_delta.sum() + test_nb_missed
- for d in range(test_nb_delta.size(0)):
- logger(
- f"error_value {n_epoch} delta {d} {test_nb_delta[d]} {test_nb_delta[d]*100/nb_total:.02f}%"
- )
- logger(
- f"error_value {n_epoch} missed {test_nb_missed} {test_nb_missed*100/nb_total:.02f}%"
- )
-
- ##############################################################
- # Log a few generated sequences
- if input_file is None:
- input = self.test_input[:10]
- else:
- with open(input_file, "r") as f:
- sequences = [e.strip() for e in f.readlines()]
- sequences = [s + " " + "#" * 50 for s in sequences]
- input = self.tensorize(sequences)
-
- result = input.clone()
- s = (result == self.space).long()
- ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1)
- result = (1 - ar_mask) * result + ar_mask * self.filler
-
- for n in range(result.size(0)):
- logger(f"test_before {self.seq2str(result[n])}")
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- correct = (1 - ar_mask) * self.space + ar_mask * input
- for n in range(result.size(0)):
- comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
- logger(f"test_after {self.seq2str(result[n])} {comment}")
- logger(f"truth {self.seq2str(correct[n])}")
- ##############################################################
-
-
-######################################################################
-
-import grid
-
-
-class Grid(Task):
- # Make a tensor from a list of strings
- def str2tensor(self, descr):
- token_descr = [s.strip().split(" ") for s in descr]
- l = max([len(s) for s in token_descr])
- token_descr = [s + ["#"] * (l - len(s)) for s in token_descr]
- id_descr = [[self.token2id[u] for u in s] for s in token_descr]
- return torch.tensor(id_descr, device=self.device)
-
- # Make a list of strings from a tensor
- def tensor2str(self, x):
- return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
-
- # trim all the tensors in the tuple z to remove as much token from
- # left and right in the first tensor. If z is a tuple, all its
- # elements are trimed according to the triming for the first
- def trim(self, z, token="#"):
- n = self.token2id[token]
- if type(z) == tuple:
- x = z[0]
- i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
- a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
- return tuple([t[:, a:b] for t in z])
- else:
- i = (1 - (F.pad(z, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
- a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
- return z[:, a:b]
-
- ######################
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- size,
- fraction_play=0.0,
- logger=None,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.device = device
- self.batch_size = batch_size
- self.grid_factory = grid.GridFactory(size=size)
- self.fraction_play = fraction_play
-
- if logger is not None:
- logger(
- f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
- )
-
- self.train_descr = self.grid_factory.generate_samples(
- nb=nb_train_samples,
- fraction_play=fraction_play,
- progress_bar=lambda r: tqdm.tqdm(r),
- )
-
- self.test_descr = self.grid_factory.generate_samples(
- nb=nb_test_samples, fraction_play=0.0, progress_bar=lambda r: tqdm.tqdm(r)
- )
-
- if fraction_play > 0:
- self.play_descr = self.grid_factory.generate_samples(
- nb=25, fraction_play=1.0, progress_bar=lambda r: tqdm.tqdm(r)
- )
- else:
- self.play_descr = []
-
- # Build the tokenizer
- tokens = set()
- for d in [self.train_descr, self.test_descr, self.play_descr]:
- for s in d:
- for t in s.strip().split(" "):
- tokens.add(t)
- # make this set a sorted list to get the same tensors given
- # the same descr
- tokens = list(tokens)
- tokens.sort()
- tokens = ["#"] + tokens
- self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
- self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
- self.t_nul = self.token2id["#"]
- self.t_true = self.token2id["true"]
- self.t_false = self.token2id["false"]
- # self.t_pipe = self.token2id["|"]
-
- # Tokenize the train and test sets
- self.train_input = self.str2tensor(self.train_descr)
- self.test_input = self.str2tensor(self.test_descr)
- self.play_input = (
- None if len(self.play_descr) == 0 else self.str2tensor(self.play_descr)
- )
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
- ):
- yield self.trim(batch)
-
- def vocabulary_size(self):
- return len(self.token2id)
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- correct = self.test_input[:1000]
- result = correct.clone()
- ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
- result *= 1 - ar_mask # paraaaaanoiaaaaaaa
-
- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:10]):
- logger(f"test_before {e}")
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:10]):
- logger(f"test_after {e}")
-
- logger(f"----------------------------------------------------------")
-
- nb_total = ar_mask.sum().item()
- nb_correct = ((correct == result).long() * ar_mask).sum().item()
-
- logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
- logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
-
- if self.play_input is not None:
- result = self.play_input.clone()
- ar_mask = (result == self.t_pipe).long().cumsum(dim=1).clamp(max=1)
- result *= 1 - ar_mask # paraaaaanoiaaaaaaa
-
- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:10]):
- logger(f"play_before {e}")
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:10]):
- logger(f"play_after {e}")
-
- logger(f"----------------------------------------------------------")
-
-
-######################################################################
-
-import qmlp
-
-
-class QMLP(Task):
- ######################
-
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- result_dir,
- logger=None,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.device = device
- self.batch_size = batch_size
- self.nb_samples_per_mlp = 256
-
- if logger is not None:
- logger(
- f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
- )
-
- seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
- nb_mlps=nb_train_samples + nb_test_samples,
- nb_samples=self.nb_samples_per_mlp,
- device=self.device,
- batch_size=64,
- nb_epochs=250,
- nb_mlps_per_batch=1024,
- )
-
- self.train_input = seq[:nb_train_samples]
- self.train_q_test_set = q_test_set[:nb_train_samples]
- self.train_ref_test_errors = test_error[:nb_train_samples]
- self.test_input = seq[nb_train_samples:]
- self.test_q_test_set = q_test_set[nb_train_samples:]
- self.test_ref_test_errors = test_error[nb_train_samples:]
-
- filename = os.path.join(result_dir, f"train_errors_ref.dat")
- with open(filename, "w") as f:
- for e in self.train_ref_test_errors:
- f.write(f"{e}\n")
-
- filename = os.path.join(result_dir, f"test_errors_ref.dat")
- with open(filename, "w") as f:
- for e in self.test_ref_test_errors:
- f.write(f"{e}\n")
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
- ):
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis
- ):
- correct = self.test_input[:1000]
- result = correct.clone()
- ar_mask = (
- torch.arange(result.size(1), device=result.device)
- > self.nb_samples_per_mlp * 3 + 1
- ).long()[None, :]
- ar_mask = ar_mask.expand_as(result)
- result *= 1 - ar_mask # paraaaaanoiaaaaaaa
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- q_train_set = result[:, : self.nb_samples_per_mlp * 3]
- q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
- error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
-
- filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
- with open(filename, "w") as f:
- for e in error_test:
- f.write(f"{e}\n")
-
-
-######################################################################
-
-import greed
-
-
-class Greed(Task):
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- batch_size,
- height,
- width,
- T,
- nb_walls,
- nb_coins,
- logger=None,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.batch_size = batch_size
- self.device = device
-
- self.world = greed.GreedWorld(height, width, T, nb_walls, nb_coins)
-
- states, actions, rewards = self.world.generate_episodes(
- nb_train_samples + nb_test_samples
- )
- seq = self.world.episodes2seq(states, actions, rewards)
- self.train_input = seq[:nb_train_samples].to(self.device)
- self.test_input = seq[nb_train_samples:].to(self.device)
-
- def wipe_lookahead_rewards(self, batch):
- t = torch.arange(batch.size(1), device=batch.device)[None, :]
- u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device)
- lr_mask = (t <= u).long() * (
- t % self.world.it_len == self.world.index_lookahead_reward
- ).long()
-
- return (
- lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
- + (1 - lr_mask) * batch
- )
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- yield self.wipe_lookahead_rewards(batch)
-
- def vocabulary_size(self):
- return self.world.nb_codes
-
- def thinking_autoregression(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
- ):
- snapshots = []
-
- def ar(result, ar_mask, logit_biases=None):
- ar_mask = ar_mask.expand_as(result)
- result *= 1 - ar_mask
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis=deterministic_synthesis,
- logit_biases=logit_biases,
- device=self.device,
- progress_bar_desc=None,
- )
- warnings.warn("keeping thinking snapshots", RuntimeWarning)
- snapshots.append(result[:100].detach().clone())
-
- # Generate iteration after iteration
-
- result = self.test_input[:250].clone()
- # Erase all the content but that of the first iteration
- result[:, self.world.it_len :] = -1
- # Set the lookahead_reward of the firs to UNKNOWN
- result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code(
- greed.REWARD_UNKNOWN
- )
-
- t = torch.arange(result.size(1), device=result.device)[None, :]
-
- for u in tqdm.tqdm(
- range(0, result.size(1), self.world.it_len),
- desc="thinking",
- ):
- # Generate the next state but keep the initial one, the
- # lookahead_reward of previous iterations are set to
- # UNKNOWN
- if u > 0:
- result[
- :, u + self.world.index_lookahead_reward
- ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
- ar_mask = (t >= u + self.world.index_states).long() * (
- t < u + self.world.index_states + self.world.state_len
- ).long()
- ar(result, ar_mask)
-
- # Generate the action and reward with lookahead_reward to +1
- result[
- :, u + self.world.index_lookahead_reward
- ] = self.world.lookahead_reward2code(greed.REWARD_PLUS)
- ar_mask = (t >= u + self.world.index_reward).long() * (
- t <= u + self.world.index_action
- ).long()
- ar(result, ar_mask)
-
- # Set the lookahead_reward to UNKNOWN for the next iterations
- result[
- :, u + self.world.index_lookahead_reward
- ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN)
-
- filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
- with open(filename, "w") as f:
- for n in range(snapshots[0].size(0)):
- for s in snapshots:
- lr, s, a, r = self.world.seq2episodes(
- s[n : n + 1],
- )
- str = self.world.episodes2str(
- lr, s, a, r, unicode=True, ansi_colors=True
- )
- f.write(str)
- f.write("\n\n")
-
- # Saving the generated sequences
-
- lr, s, a, r = self.world.seq2episodes(result)
- str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
-
- filename = os.path.join(result_dir, f"test_thinking_seq_{n_epoch:04d}.txt")
- with open(filename, "w") as f:
- f.write(str)
- logger(f"wrote {filename}")
-
- def produce_results(
- self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
- ):
- result = self.wipe_lookahead_rewards(self.test_input[:250].clone())
-
- # Saving the ground truth
-
- lr, s, a, r = self.world.seq2episodes(
- result,
- )
- str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
-
- filename = os.path.join(result_dir, f"test_true_seq_{n_epoch:04d}.txt")
- with open(filename, "w") as f:
- f.write(str)
- logger(f"wrote {filename}")
-
- # Re-generating from the first frame
-
- ar_mask = (
- torch.arange(result.size(1), device=result.device) >= self.world.it_len
- ).long()[None, :]
- ar_mask = ar_mask.expand_as(result)
- result *= 1 - ar_mask # paraaaaanoiaaaaaaa
-
- masked_inplace_autoregression(
- model,
- self.batch_size,
- result,
- ar_mask,
- deterministic_synthesis,
- device=self.device,
- )
-
- # Saving the generated sequences
-
- lr, s, a, r = self.world.seq2episodes(
- result,
- )
- str = self.world.episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)
-
- filename = os.path.join(result_dir, f"test_seq_{n_epoch:04d}.txt")
- with open(filename, "w") as f:
- f.write(str)
- logger(f"wrote {filename}")
-
- self.thinking_autoregression(
- n_epoch, model, result_dir, logger, deterministic_synthesis, nmax
- )
-
-
-######################################################################
######################################################################
import world
+++ /dev/null
-#!/usr/bin/env python
-
-import torch
-
-
-def generate_turing_sequences(N, nb_iter=5, nb_states=3, nb_symbols=4, tape_size=5):
- next_state = torch.randint(nb_states, (N, nb_states, nb_symbols))
- next_symbol = torch.randint(nb_symbols, (N, nb_states, nb_symbols))
- next_move = torch.randint(3, (N, nb_states, nb_symbols))
-
- all_n = torch.arange(N)
-
- tape = torch.randint(nb_symbols, (N, tape_size))
- # position = torch.randint(tape_size, (N,))
- # state = torch.randint(nb_states, (N,))
- position = torch.zeros(N, dtype=torch.int64)
- state = torch.zeros(N, dtype=torch.int64)
-
- result = []
-
- for _ in range(nb_iter):
- result.append(tape.clone())
- current_symbol = tape[all_n, position]
- tape[all_n, position] = next_symbol[all_n, state, current_symbol]
- position = (position + next_move[all_n, state, current_symbol] - 1) % tape_size
- state = next_state[all_n, state, current_symbol]
-
- result = torch.cat([x[:, None, :] for x in result], dim=1)
-
- return result
-
-
-######################################################################
-
-if __name__ == "__main__":
- print("Basic check.")
-
- tapes = generate_turing_sequences(1, nb_iter=10)
-
- for i in range(tapes.size(1)):
- # print(f"- {i:03d} ------------------------")
- # for s, h, r in zip(state, position, tape):
- # print("".join([f"{x}" for x in r]))
- # print(" " * h + f"^[{s}]")
- for r in tapes:
- print("".join([f"{x}" for x in r[i]]))