entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
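+# exp of the Shannon entropy of the empirical token distribution, i.e. the
+# perplexity an optimal unigram model would reach on the train set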
-##############################
-
+######################################################################
# A bit of paranoia never hurts
-train_examples = {}
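+# Yields the rows of `batches` as sets of tuples, `cs` rows per chunk,
+# so the train and test splits can be compared for overlap chunk by
+# chunk instead of being materialized entirely as python sets.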
+def subsets_as_tuples(batches, cs):
+    s = set()
+    for batch in batches:
+        for x in batch:
+            s.add(tuple([v.item() for v in x]))
+            if len(s) == cs:
+                yield s
+                s = set()
+    yield s
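+# e.g. list(subsets_as_tuples([torch.tensor([[1, 2], [3, 4], [5, 6]])], 2))
+# yields {(1, 2), (3, 4)} first, then the incomplete remainder {(5, 6)}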
-for input in task.batches(split="train"):
-    assert input.dim() == 2 and input.dtype == torch.int64
-    for x in input:
-        train_examples[x.sum().item()] = x
-
-nb_total, nb_collisions = 0, 0
-for input in task.batches(split="test"):
-    assert input.dim() == 2 and input.dtype == torch.int64
-    for x in input:
-        nb_total += 1
-        y = train_examples.get(x.sum().item())
-        if y is not None:
-            if x.size() == y.size() and (x - y).abs().sum() == 0:
-                nb_collisions += 1
-
-del train_examples
+
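+# every test chunk is checked against every train chunk, so no overlap is
+# missed while only a couple of 25000-row chunks are in memory at a time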
+nb_test, nb_in_train = 0, 0
+for test_subset in subsets_as_tuples(task.batches(split="test"), 25000):
+    in_train = set()
+    for train_subset in subsets_as_tuples(task.batches(split="train"), 25000):
+        in_train.update(test_subset.intersection(train_subset))
+    nb_in_train += len(in_train)
+    nb_test += len(test_subset)
log_string(
- f"data_check {nb_collisions*100/nb_total:.02f}% ({nb_collisions}/{nb_total}) of test samples are in the train set"
+ f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"
)
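+# tolerate at most 1% of leaked test samples before aborting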
+assert (
+    nb_in_train <= nb_test // 100
+), "More than 1% of test samples are in the train set"
+
##############################
if args.learning_rate_schedule == "cos":
class ProblemLenId(Problem):
-    def __init__(self, nb_sentences=100, len_max=5):
+    def __init__(self, len_max=10):
        self.len_max = len_max
    def generate_sequences(self, nb):
            + (k > a) * (k < b) * i[1]
            + (k == b) * 11
            + (k > b) * (k < c) * i[1]
-            + (k == c) * 12
-            + (k > c) * 13
+            + (k >= c) * 12
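+            # every position from c onward is a single pad token 12 ("_"),
+            # which replaces the former pair of tokens 12 and 13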
        )
        ar_mask = (sequences == 11).long()
        ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
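+        # the mask is 1 strictly after the first ">" (token 11): the suffix to generate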
        return sequences, ar_mask
    def seq2str(self, seq):
-        return "".join("0123456789|>.?"[x.item()] for x in seq)
+        return "".join("0123456789|>_"[x.item()] for x in seq)
####################