    # NOTE(review): this chunk begins mid-definition; the enclosing
    # __init__ signature lies above the visible region.
    self.pruner_train = pruner_train
    self.pruner_eval = pruner_eval

    # Capture the generation settings plus the torch RNG state —
    # presumably for dataset caching/reproducibility; `param` is not
    # used anywhere in the visible region, confirm against the full file.
    param = {
        "nb_train_samples": nb_train_samples,
        "nb_test_samples": nb_test_samples,
        "height": height,
        "width": width,
        "nb_colors": nb_colors,
        "batch_size": batch_size,
        "rng_state": list(torch.get_rng_state()),
    }

    if logger is not None:
        logger(
            f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
        )

    # Training descriptions may be filtered by pruner_train; the test
    # set is always generated unpruned.
    self.train_descr = generate_descr(
        nb_train_samples, "train", pruner=self.pruner_train
    )
    self.test_descr = generate_descr(nb_test_samples, "test", pruner=None)

    # Build the tokenizer: collect every whitespace-separated token seen
    # in either split, seeded with the two special tokens.
    tokens = {"<nul>", "<img>"}
    for d in [self.train_descr, self.test_descr]:
        for s in d:
            for t in s.strip().split(" "):
                tokens.add(t)
    # make this set a sorted list to get the same tensors given
    # the same descr
    tokens = list(tokens)
    tokens.sort()
    self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
    self.id2token = dict([(n, t) for n, t in enumerate(tokens)])

    # Tokenize the train and test sets
    self.train_input = self.tensorize(self.train_descr)
    self.test_input = self.tensorize(self.test_descr)
def batches(self, split="train"):
    """Yield trimmed batches of the tensorized train or test set.

    Args:
        split: "train" or "test"; selects which tokenized tensor to
            iterate over.

    Yields:
        Chunks of ``self.batch_size`` rows (the last chunk may be
        smaller), each passed through ``self.trim``.
    """
    assert split in {"train", "test"}
    # Renamed from `input`, which shadowed the builtin of the same name.
    source = self.train_input if split == "train" else self.test_input
    for batch in tqdm.tqdm(
        source.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
    ):
        yield self.trim(batch)
-
def vocabulary_size(self):
    """Return the number of distinct tokens known to the tokenizer."""
    vocabulary = self.token2id
    return len(vocabulary)
-
def compute_missing_properties(
    self, n_epoch, model, logger, deterministic_synthesis, pruner=None
):
    """Measure how many requested properties the model's generated
    images fail to satisfy on the test set.

    For each test batch, the last image is excised from the token tape,
    re-generated by ``model``, and the resulting descriptions are
    checked with ``picoclvr.nb_properties``. Aggregate counts are
    reported through ``logger``.

    Args:
        n_epoch: current epoch index, used only in the log lines.
        model: generative model used to redraw the excised image.
        logger: callable accepting a single message string.
        deterministic_synthesis: forwarded to ``add_generated_image``.
        pruner: optional property filter; when given, logged metric
            names gain a "pruned_" prefix.
    """
    acc_nb_requested_properties = []
    acc_nb_missing_properties = []
    acc_nb_results = 0

    # `batch` renamed from `input` (shadowed the builtin); the desc is a
    # plain string — the original `f"test-properties"` had no placeholder.
    for batch in tqdm.tqdm(
        self.test_input.split(self.batch_size),
        dynamic_ncols=True,
        desc="test-properties",
    ):
        tape, loss_masks, _ = self.excise_last_image(batch)
        tape, loss_masks = self.add_generated_image(
            tape, loss_masks, model, deterministic_synthesis
        )
        result_descr = self.detensorize(tape)
        # `nbp` rather than `np`, which shadowed the conventional numpy alias.
        nbp = picoclvr.nb_properties(
            result_descr,
            height=self.height,
            width=self.width,
            pruner=pruner,
        )
        nb_requested_properties, _, nb_missing_properties = zip(*nbp)
        acc_nb_requested_properties += nb_requested_properties
        acc_nb_missing_properties += nb_missing_properties
        acc_nb_results += len(result_descr)

    nb_requested_properties = sum(acc_nb_requested_properties)
    nb_missing_properties = sum(acc_nb_missing_properties)

    prefix = "" if pruner is None else "pruned_"
    logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
    # Reuse the totals computed above instead of re-summing the lists.
    logger(
        f"property_{prefix}nb {n_epoch} requested {nb_requested_properties} missing {nb_missing_properties}"
    )
    logger(
        f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
    )
-
- ######################################################################
-
def produce_results(
    self, n_epoch, model, result_dir, logger, deterministic_synthesis
):
    """End-of-epoch evaluation: missing-property statistics plus a small
    demo batch generated from fixed primer descriptions.

    NOTE(review): this definition continues past the end of the visible
    chunk (it ends after building `img` below).
    """
    self.compute_missing_properties(n_epoch, model, logger, deterministic_synthesis)

    if self.pruner_eval is not None:
        # FIXME(review): compute_missing_properties takes
        # (n_epoch, model, logger, deterministic_synthesis, pruner=None).
        # This call passes self.pruner_eval where `logger` is expected and
        # omits deterministic_synthesis, so it raises TypeError at runtime.
        # It should read:
        #   self.compute_missing_properties(
        #       n_epoch, model, logger, deterministic_synthesis, self.pruner_eval
        #   )
        self.compute_missing_properties(n_epoch, model, self.pruner_eval)

    # NOTE(review): nb_tokens_to_generate is unused in the visible
    # region — presumably consumed after this chunk; confirm.
    nb_tokens_to_generate = self.height * self.width + 3
    result_descr = []
    nb_per_primer = 8
    primer = []

    # Repeat each primer description nb_per_primer times so several
    # samples are generated per prompt.
    for primer_descr in [
        "red above green <sep> green top <sep> blue right of red",
        "there is red <sep> there is yellow <sep> there is blue",
        "red below yellow <sep> yellow below green <sep> green below blue <sep> red right <sep> yellow left <sep> green right <sep> blue left",
        "green bottom <sep> yellow bottom <sep> green left of blue <sep> yellow right of blue <sep> blue top",
    ]:
        primer += [primer_descr] * nb_per_primer

    tape = self.tensorize(primer)
    # Mask <nul> padding positions out of the loss.
    loss_masks = 1 - (tape == self.token2id["<nul>"]).long()
    tape, loss_masks = self.add_generated_image(
        tape, loss_masks, model, deterministic_synthesis
    )
    result_descr = self.detensorize(tape)

    # NOTE(review): local `np` shadows the conventional numpy alias.
    np = picoclvr.nb_properties(result_descr, height=self.height, width=self.width)

    acc_nb_requested_properties, _, acc_nb_missing_properties = zip(*np)
    acc_nb_results = len(result_descr)

    nb_requested_properties = sum(acc_nb_requested_properties)
    nb_missing_properties = sum(acc_nb_missing_properties)

    prefix = "demo_"
    logger(f"nb_{prefix}samples {n_epoch} {acc_nb_results}")
    logger(
        f"property_{prefix}nb {n_epoch} requested {sum(acc_nb_requested_properties)} missing {sum(acc_nb_missing_properties)}"
    )
    logger(
        f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
    )

    img = picoclvr.descr2img(result_descr, height=self.height, width=self.width)