- pairs = []
- with open(filename, "r") as f:
- for _ in range(nb_train_samples + nb_test_samples):
- sequence = f.readline().strip()
- pred_mask = f.readline().strip()
- assert len(sequence) == len(pred_mask)
- assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
- pairs.append((sequence, pred_mask))
-
- symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
+ def read_file(filename, nb=-1):
+ pairs = []
+ with open(filename, "r") as f:
+ while True:
+ sequence = f.readline().strip()
+ if not sequence:
+ break
+ pred_mask = f.readline().strip()
+ assert len(sequence) == len(pred_mask)
+ assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
+ pairs.append((sequence, pred_mask))
+ if len(pairs) == nb:
+ break
+
+ if nb > 0:
+ pairs = pairs[:nb]
+ assert len(pairs) == nb
+
+ return pairs
+
+ train_pairs = read_file(train_filename, nb_train_samples)
+ test_pairs = read_file(test_filename, nb_test_samples)
+
+ symbols = ["#"] + list(
+ set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
+ )