def generate_sequences(self, nb):
sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
- ar_mask = (sequences==10).long()
+ ar_mask = (sequences == 10).long()
ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
return sequences, ar_mask
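# A toy illustration (not from the original file) of the ar_mask construction
# above, assuming token 10 acts as the prompt/answer separator: every position
# strictly after its first occurrence is marked for autoregressive generation.
#
#   seq = torch.tensor([[3, 7, 10, 4, 5]])
#   m = (seq == 10).long()            # [[0, 0, 1, 0, 0]]
#   (m.cumsum(1) - m).clamp(max=1)    # [[0, 0, 0, 1, 1]]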
# nb_common_codes = 100
# def generate_sequences(nb_samples):
- # problem_indexes = torch.randint(len(problems), (nb_samples,))
- # nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
- # print(f"{nb_samples_per_problem}")
- # all_seq = []
- # for nb, p in zip(nb_samples_per_problem, problems):
- # all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
- # return all_seq
+ #     problem_indexes = torch.randint(len(problems), (nb_samples,))
+ #     nb_samples_per_problem = torch.nn.functional.one_hot(
+ #         problem_indexes, num_classes=len(problems)
+ #     ).sum(0)
+ #     print(f"{nb_samples_per_problem}")
+ #     all_seq = []
+ #     for nb, p in zip(nb_samples_per_problem, problems):
+ #         all_seq.append(p.generate_sequences(nb))
+ #     return all_seq
# for strain, stest in zip(train_seq, test_seq):
- # s = torch.cat((strain, stest), 0)
+ # s = torch.cat((strain, stest), 0)
+
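# Hypothetical, self-contained version of the commented-out sketch above; the
# function name and the num_classes handling are assumptions, not part of the
# original file. Note that one-hot encoding lives in torch.nn.functional
# (there is no torch.one_hot).
import torch
import torch.nn.functional as F


def generate_mixed_sequences(problems, nb_samples):
    # Draw one problem index per sample, then count the samples per problem.
    problem_indexes = torch.randint(len(problems), (nb_samples,))
    nb_samples_per_problem = F.one_hot(
        problem_indexes, num_classes=len(problems)
    ).sum(0)
    # Let each problem generate its own share of the sequences.
    return [
        p.generate_sequences(nb)
        for nb, p in zip(nb_samples_per_problem.tolist(), problems)
    ]
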
class SandBox(Task):
    def __init__(
        self,
        problem,
        nb_train_samples,
        nb_test_samples,
        batch_size,
        logger=None,
        device=torch.device("cpu"),
+ max_nb_codes=1024,
):
super().__init__()
self.batch_size = batch_size
self.device = device
- self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples)
+ self.train_input, self.train_ar_mask = problem.generate_sequences(
+ nb_train_samples
+ )
self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+ # A bit of paranoia never hurts
+ assert (
+             self.nb_codes <= max_nb_codes
+             and self.train_input.min() >= 0
+             and self.test_input.min() >= 0
+             and tuple(self.train_ar_mask.unique()) == (0, 1)
+             and tuple(self.test_ar_mask.unique()) == (0, 1)
+ )
+
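    # Hypothetical construction sketch (the problem object and sample counts
    # are made-up placeholders, not from this file):
    #
    #   task = SandBox(
    #       problem=some_problem,  # anything exposing generate_sequences(nb)
    #       nb_train_samples=10000,
    #       nb_test_samples=1000,
    #       batch_size=25,
    #   )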
    def batches(self, split="train", nb_to_use=-1, desc=None):
        assert split in {"train", "test"}
        input = self.train_input if split == "train" else self.test_input
        if nb_to_use > 0:
            input = input[:nb_to_use]
        yield from input.split(self.batch_size)  # desc: reserved for a progress-bar label
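    # Typical consumption (sketch, not from this file):
    #
    #   for input in task.batches(split="train"):
    #       input = input.to(task.device)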
def produce_results(
self, n_epoch, model, result_dir, logger, deterministic_synthesis
):
-
def compute_accuracy(input, ar_mask):
- result = input.clone() * (1-ar_mask)
+ result = input.clone() * (1 - ar_mask)
            masked_inplace_autoregression(
                model,
                self.batch_size,
                result,  # filled in place where ar_mask == 1
                ar_mask,
                deterministic_synthesis,
                device=self.device,
            )
nb_total = ar_mask.sum().item()
- nb_correct = ((result==input).long() * ar_mask).sum().item()
+ nb_correct = ((result == input).long() * ar_mask).sum().item()
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask)
+ train_nb_total, train_nb_correct = compute_accuracy(
+ self.train_input, self.train_ar_mask
+ )
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask)
+ test_nb_total, test_nb_correct = compute_accuracy(
+ self.test_input, self.test_ar_mask
+ )
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+
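# Toy check (illustration only, not from the original file) of the masked
# accuracy logic above: only positions where ar_mask == 1 are scored.
#
#   input   = torch.tensor([[3, 7, 10, 4, 5]])
#   result  = torch.tensor([[3, 7, 10, 4, 9]])
#   ar_mask = torch.tensor([[0, 0, 0, 1, 1]])
#   ar_mask.sum().item()                               # nb_total   -> 2
#   ((result == input).long() * ar_mask).sum().item()  # nb_correct -> 1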
######################################################################
import picoclvr