-class TaskExpr(Task):
- def __init__(
- self,
- nb_train_samples,
- nb_test_samples,
- nb_variables,
- sequence_length,
- batch_size,
- device=torch.device("cpu"),
- ):
- self.batch_size = batch_size
- self.device = device
-
- train_sequences = expr.generate_sequences(
- nb_train_samples,
- nb_variables=nb_variables,
- length=sequence_length,
- # length=2 * sequence_length,
- # randomize_length=True,
- )
- test_sequences = expr.generate_sequences(
- nb_test_samples,
- nb_variables=nb_variables,
- length=sequence_length,
- )
- self.char2id = dict(
- [
- (c, n)
- for n, c in enumerate(
- set("#" + "".join(train_sequences + test_sequences))
- )
- ]
- )
- self.id2char = dict([(n, c) for c, n in self.char2id.items()])
-
- self.filler, self.space = self.char2id["#"], self.char2id[" "]
-
- len_max = max([len(x) for x in train_sequences])
- self.train_input = torch.cat(
- [
- torch.tensor(
- [
- [self.char2id[c] for c in s + "#" * (len_max - len(s))]
- for s in train_sequences
- ]
- )
- ],
- 0,
- ).to(device)
-
- len_max = max([len(x) for x in test_sequences])
- self.test_input = torch.cat(
- [
- torch.tensor(
- [
- [self.char2id[c] for c in s + "#" * (len_max - len(s))]
- for s in test_sequences
- ]
- )
- ],
- 0,
- ).to(device)
-
- self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
-
- def batches(self, split="train", nb_to_use=-1, desc=None):
- assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
- if desc is None:
- desc = f"epoch-{split}"
- for batch in tqdm.tqdm(
- input.split(self.batch_size), dynamic_ncols=True, desc=desc
- ):
- if split == "train":
- last = (batch != self.filler).max(0).values.nonzero().max() + 1
- batch = batch[:, :last]
- yield batch
-
- def vocabulary_size(self):
- return self.nb_codes
-
- def seq2str(self, s):
- return "".join([self.id2char[k.item()] for k in s])
-
- def produce_results(self, n_epoch, model):
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- def compute_nb_correct(input):
- result = input.clone()
- ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
- result = (1 - ar_mask) * result + ar_mask * self.filler
- masked_inplace_autoregression(
- model, self.batch_size, result, ar_mask, device=self.device
- )
-
- nb_total = input.size(0)
- nb_correct = (input == result).long().min(1).values.sum()
-
- #######################################################################
- # Comput predicted vs. true variable values
-
- values_input = expr.extract_results([self.seq2str(s) for s in input])
- max_input = max([max(x.values()) for x in values_input])
- values_result = expr.extract_results([self.seq2str(s) for s in result])
- max_result = max(
- [-1 if len(x) == 0 else max(x.values()) for x in values_result]
- )
-
- nb_missing = torch.zeros(max_input + 1)
- nb_predicted = torch.zeros(max_input + 1, max_result + 1)
-
- for i, r in zip(values_input, values_result):
- for n, vi in i.items():
- vr = r.get(n)
- if vr is None or vr < 0:
- nb_missing[vi] += 1
- else:
- nb_predicted[vi, vr] += 1
- ######################################################################
-
- return nb_total, nb_correct
-
- test_nb_total, test_nb_correct = compute_nb_correct(self.test_input[:1000])
-
- log_string(
- f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- )
-
- ##############################################################
- # Log a few generated sequences
- input = self.test_input[:10]
- result = input.clone()
- ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1)
- result = (1 - ar_mask) * result + ar_mask * self.filler
- for n in range(result.size(0)):
- log_string(f"test_before {self.seq2str(result[n])}")
- masked_inplace_autoregression(
- model, self.batch_size, result, ar_mask, device=self.device
- )
- correct = (1 - ar_mask) * self.space + ar_mask * input
- for n in range(result.size(0)):
- comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else ""
- log_string(f"test_after {self.seq2str(result[n])} {comment}")
- log_string(f"correct {self.seq2str(correct[n])}")
- ##############################################################