return properties
- def generate_example(self):
+ def generate_scene_and_questions(self):
while True:
while True:
scene = self.generate_scene()
if len(false) >= self.nb_questions:
break
+ # print(f"{a=}")
+
if a < 10:
break
true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
- true = [(q, "yes") for q in true]
- false = [(q, "no") for q in false]
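+ # wrap each question as "<prop> ... <true>" or "<prop> ... <false>"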
+ true = ["<prop> " + q + " <true>" for q in true]
+ false = ["<prop> " + q + " <false>" for q in false]
union = true + false
questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
- return scene, questions
+ result = " ".join(
+ ["<obj> " + x for x in self.grid_positions(scene)] + questions
+ )
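+ # result reads "<obj> ... <obj> ... <prop> ... <true/false> ...":
+ # one <obj> entry per element returned by grid_positions(),
+ # followed by the shuffled questions, each carrying its answer token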
+
+ return scene, result
+
+ def generate_samples(self, nb, progress_bar=None):
+ result = []
+
+ r = range(nb)
+ if progress_bar is not None:
+ r = progress_bar(r)
+
+ for _ in r:
+ result.append(self.generate_scene_and_questions()[1])
+
+ return result
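+ # e.g. generate_samples(1000, progress_bar=tqdm.tqdm) displays a
+ # progress bar (assuming tqdm is available in the caller's scope)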
######################################################################
if __name__ == "__main__":
+ import time
+
grid_factory = GridFactory()
- scene, questions = grid_factory.generate_example()
+
+ start_time = time.perf_counter()
+ samples = grid_factory.generate_samples(10000)
+ end_time = time.perf_counter()
+ print(f"{len(samples) / (end_time - start_time):.02f} samples per second")
+
+ scene, questions = grid_factory.generate_scene_and_questions()
grid_factory.print_scene(scene)
print(questions)
"--task",
type=str,
default="twotargets",
- help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl",
+ help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl, grid",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
"nb_train_samples": 25000,
"nb_test_samples": 1000,
},
+ "grid": {
+ "model": "37M",
+ "nb_epochs": 25,
+ "batch_size": 25,
+ "nb_train_samples": 250000,
+ "nb_test_samples": 10000,
+ },
}
if args.task in default_task_args:
device=device,
)
+elif args.task == "grid":
+ task = tasks.Grid(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
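+ # the grid task reuses the picoclvr height/width arguments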
+ height=args.picoclvr_height,
+ width=args.picoclvr_width,
+ logger=log_string,
+ device=device,
+ )
+
elif args.task == "world":
task = tasks.World(
nb_train_samples=args.nb_train_samples,
##############################################################
+######################################################################
+
+import grid
+
+
+class Grid(Task):
+ # Make a tensor from a list of strings
+ def tensorize(self, descr):
+ token_descr = [s.strip().split(" ") for s in descr]
+ max_len = max([len(s) for s in token_descr])
+ token_descr = [s + ["<nul>"] * (max_len - len(s)) for s in token_descr]
+ id_descr = [[self.token2id[u] for u in s] for s in token_descr]
+ return torch.tensor(id_descr, device=self.device)
+
+ # Make a list of strings from a tensor
+ def detensorize(self, x):
+ return [" ".join([self.id2token[t.item()] for t in r]) for r in x]
+
+ # Trim the tensors to remove as many <nul> tokens as possible
+ # from the left and right of the first tensor. If z is a tuple,
+ # all its elements are trimmed according to the trimming
+ # computed for that first tensor.
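+ # For instance, if the first tensor's rows read [<nul>, a, b, <nul>]
+ # and [<nul>, c, <nul>, <nul>], the all-<nul> outer columns are
+ # dropped and only the middle two columns are kept.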
+ def trim(self, z, token="<nul>"):
+ n = self.token2id[token]
+ x = z[0] if isinstance(z, tuple) else z
+ i = (1 - (F.pad(x, (1, 1), value=n) == n).min(0).values.long()).cumsum(0)
+ a, b = (i == 0).nonzero().max(), (i == i.max()).nonzero().min()
+ if isinstance(z, tuple):
+ return tuple([t[:, a:b] for t in z])
+ else:
+ return z[:, a:b]
+
+ ######################
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ height,
+ width,
+ logger=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.device = device
+ self.batch_size = batch_size
+ self.grid_factory = grid.GridFactory(height=height, width=width)
+
+ if logger is not None:
+ logger(
+ f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+ )
+
+ self.train_descr = self.grid_factory.generate_samples(
+ nb_train_samples, lambda r: tqdm.tqdm(r)
+ )
+ self.test_descr = self.grid_factory.generate_samples(
+ nb_test_samples, lambda r: tqdm.tqdm(r)
+ )
+
+ # Build the tokenizer
+ tokens = set()
+ for d in [self.train_descr, self.test_descr]:
+ for s in d:
+ for t in s.strip().split(" "):
+ tokens.add(t)
+ # turn the set into a sorted list so that the same descriptions
+ # always produce the same token ids, hence the same tensors
+ tokens = sorted(tokens)
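+ # put the padding token first so that it always gets id 0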
+ tokens = ["<nul>"] + tokens
+ self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+ self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+ self.t_nul = self.token2id["<nul>"]
+ self.t_true = self.token2id["<true>"]
+ self.t_false = self.token2id["<false>"]
+
+ # Tokenize the train and test sets
+ self.train_input = self.tensorize(self.train_descr)
+ self.test_input = self.tensorize(self.test_descr)
+
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
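+ # trim the leading/trailing all-<nul> columns of each batch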
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ ):
+ yield self.trim(batch)
+
+ def vocabulary_size(self):
+ return len(self.token2id)
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
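+ # blank out the <true> / <false> answer tokens of the first 1000
+ # test sequences and let the model regenerate them; accuracy is
+ # measured on those masked positions only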
+ correct = self.test_input[:1000]
+ result = correct.clone()
+ ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+ result *= 1 - ar_mask
+
+ for e in self.detensorize(result[:10]):
+ logger(f"test_before {e}")
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ for e in self.detensorize(result[:10]):
+ logger(f"test_after {e}")
+
+ nb_total = ar_mask.sum().item()
+ nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+ logger(f"test_performance {nb_total=} {nb_correct=}")
+ logger(f"main_test_accuracy {nb_correct / nb_total}")
+
+
######################################################################
import world