class Problem:
    """Abstract interface for a sequence-generation problem.

    Concrete subclasses must build batches of token sequences together
    with an auto-regressive mask, and may report quality metrics.
    """

    def generate_sequences(self, nb):
        """Return ``(sequences, ar_mask)`` for ``nb`` samples (stub: None)."""

    def log_performance(self, sequences, logger):
        """Report problem-specific metrics via ``logger`` (stub: None)."""
-
-
class ProblemByheart(Problem):
    """Toy memorization problem.

    Builds a fixed pool of random sequences of the form
    ``<prompt> <separator> <result>`` and samples from that pool, so a
    model can only succeed by learning the pool by heart.  Payload tokens
    are drawn from ``0 .. nb_symbols-1``; the value ``nb_symbols`` itself
    is reserved as the prompt/result separator.
    """

    def __init__(self, nb_sentences=100, len_prompt=5, len_result=5, nb_symbols=10):
        # Defaults reproduce the original hard-coded 100/5/5 pool with a
        # 10-symbol payload vocabulary.
        self.nb_symbols = nb_symbols
        self.seq = torch.randint(
            nb_symbols, (nb_sentences, len_prompt + 1 + len_result)
        )
        # Column len_prompt holds the separator token in every sequence.
        self.seq[:, len_prompt] = nb_symbols

    def generate_sequences(self, nb):
        """Sample ``nb`` sequences (with replacement) from the fixed pool.

        Returns ``(sequences, ar_mask)`` where ``ar_mask`` is 1 on every
        position strictly after the separator (the part the model must
        generate auto-regressively) and 0 elsewhere.
        """
        sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
        ar_mask = (sequences == self.nb_symbols).long()
        # cumsum turns the separator indicator into "at or after the
        # separator"; subtracting the indicator excludes the separator
        # itself, and clamp keeps the mask binary.
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
        return sequences, ar_mask
-
    # Dead exploratory sketch for sampling across several problems, kept
    # for reference:
    # problems = [ProblemByheart()]
    # nb_common_codes = 100

    # def generate_sequences(nb_samples):
    #     problem_indexes = torch.randint(len(problems), (nb_samples,))
    #     nb_samples_per_problem = torch.nn.functional.one_hot(problem_indexes).sum(0)
    #     print(f"{nb_samples_per_problem}")
    #     all_seq = []
    #     for nb, p in zip(nb_samples_per_problem, problems):
    #         all_seq.append(p.generate_sequences(nb))
    #     return all_seq

    # for strain, stest in zip(train_seq, test_seq):
    #     s = torch.cat((strain, stest), 0)
-
class SandBox(Task):
    """Task built on top of a ``Problem`` instance.

    Pre-generates train/test sequences (with their auto-regressive masks)
    at construction time, serves them in mini-batches, and evaluates a
    model by re-generating the masked positions and counting exact-token
    matches.
    """

    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger=None,
        device=torch.device("cpu"),
    ):
        super().__init__()

        self.batch_size = batch_size
        self.device = device

        # ar_mask is 1 on the positions the model has to generate, 0 on
        # the prompt positions that are given.
        self.train_input, self.train_ar_mask = problem.generate_sequences(
            nb_train_samples
        )
        self.test_input, self.test_ar_mask = problem.generate_sequences(
            nb_test_samples
        )

        # .item() so nb_codes is a plain Python int -- the original kept a
        # 0-dim tensor here, which then leaked out of vocabulary_size().
        self.nb_codes = (
            max(self.train_input.max(), self.test_input.max()).item() + 1
        )

    def batches(self, split="train", nb_to_use=-1, desc=None):
        """Yield mini-batches of the chosen split, wrapped in a tqdm bar.

        nb_to_use > 0 restricts iteration to the first nb_to_use samples.
        """
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
            input.split(self.batch_size), dynamic_ncols=True, desc=desc
        ):
            yield batch

    def vocabulary_size(self):
        """Number of distinct token values (max token seen + 1)."""
        return self.nb_codes

    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis
    ):
        """Log the model's train/test accuracy on the masked positions."""

        def compute_accuracy(input, ar_mask):
            # Blank out the positions to generate, then let the model fill
            # them back in auto-regressively; count exact-token matches on
            # the masked positions only.
            result = input.clone() * (1 - ar_mask)
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,
                ar_mask,
                deterministic_synthesis,
                progress_bar_desc=None,
                device=self.device,
            )

            nb_total = ar_mask.sum().item()
            nb_correct = ((result == input).long() * ar_mask).sum().item()

            return nb_total, nb_correct

        # One loop instead of the duplicated train/test blocks; the log
        # lines produced are identical to the original's.
        for split, input, ar_mask in [
            ("train", self.train_input, self.train_ar_mask),
            ("test", self.test_input, self.test_ar_mask),
        ]:
            nb_total, nb_correct = compute_accuracy(input, ar_mask)
            logger(
                f"accuracy_{split} {n_epoch} nb_total {nb_total} nb_correct {nb_correct} accuracy {(100.0*nb_correct)/nb_total:.02f}%"
            )
-
-######################################################################