X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=78910a06d7a24977b462af0836a9219073381507;hb=b35745d09b33aed20670ecb96726f89206487a24;hp=00b7a493481653d8eee786a5f3bf2299dacc3b01;hpb=a3f5378c684fb58b0dc839638e768c3a4b8e8a83;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 00b7a49..78910a0 100755 --- a/tasks.py +++ b/tasks.py @@ -132,16 +132,13 @@ class TaskFromFile(Task): sequence = f.readline().strip() pred_mask = f.readline().strip() assert len(sequence) == len(pred_mask) - assert set(pred_mask) == {"0", "1", "2"}, f"{set(pred_mask)}" + assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}" pairs.append((sequence, pred_mask)) symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"])) - print("SANITY", symbols) self.char2id = dict([(c, n) for n, c in enumerate(symbols)]) self.id2char = dict([(n, c) for c, n in self.char2id.items()]) - print(self.char2id) - self.train_input, self.train_pred_masks = self.tensorize( pairs[:nb_train_samples] ) @@ -163,7 +160,6 @@ class TaskFromFile(Task): return len(self.char2id) def tensor2str(self, t): - print(f"{type(t)=}") return ["".join([self.id2char[x.item()] for x in s]) for s in t] def produce_results(