##############################
# filetask
-parser.add_argument("--filetask_file", type=str, default=None)
+parser.add_argument("--filetask_train_file", type=str, default=None)
+
+parser.add_argument("--filetask_test_file", type=str, default=None)
##############################
# rpl options
if args.task == "file":
assert (
- args.filetask_file is not None
- ), "You have to specify the task file with --filetask_file <filename>"
+ args.filetask_train_file is not None and args.filetask_test_file is not None
+ ), "You have to specify the task train and test files"
task = tasks.TaskFromFile(
- args.filetask_file,
+ args.filetask_train_file,
+ args.filetask_test_file,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
def __init__(
self,
- filename,
+ train_filename,
+ test_filename,
nb_train_samples,
nb_test_samples,
batch_size,
self.batch_size = batch_size
self.device = device
- pairs = []
- with open(filename, "r") as f:
- for _ in range(nb_train_samples + nb_test_samples):
- sequence = f.readline().strip()
- pred_mask = f.readline().strip()
- assert len(sequence) == len(pred_mask)
- assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
- pairs.append((sequence, pred_mask))
-
- symbols = ["#"] + list(set("".join([x[0] for x in pairs])) - set(["#"]))
+ def read_file(filename, nb=-1):
+ pairs = []
+ with open(filename, "r") as f:
+ while True:
+ sequence = f.readline().strip()
+ if not sequence:
+ break
+ pred_mask = f.readline().strip()
+ assert len(sequence) == len(pred_mask)
+ assert set(pred_mask).issubset({"0", "1", "2"}), f"{set(pred_mask)}"
+ pairs.append((sequence, pred_mask))
+ if len(pairs) == nb:
+ break
+
+ if nb > 0:
+ pairs = pairs[:nb]
+ assert len(pairs) == nb
+
+ return pairs
+
+ train_pairs = read_file(train_filename, nb_train_samples)
+ test_pairs = read_file(test_filename, nb_test_samples)
+
+ symbols = ["#"] + list(
+ set("".join([x[0] for x in train_pairs + test_pairs])) - set(["#"])
+ )
self.char2id = dict([(c, n) for n, c in enumerate(symbols)])
self.id2char = dict([(n, c) for c, n in self.char2id.items()])
- self.train_input, self.train_pred_masks = self.tensorize(
- pairs[:nb_train_samples]
- )
- self.test_input, self.test_pred_masks = self.tensorize(pairs[nb_train_samples:])
-
- assert self.train_input.size(0) == nb_train_samples
- assert self.test_input.size(0) == nb_test_samples
+ self.train_input, self.train_pred_masks = self.tensorize(train_pairs)
+ self.test_input, self.test_pred_masks = self.tensorize(test_pairs)
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
logger(f"----------------------------------------------------------")
- for e in self.tensor2str(result[:10]):
+ for e in self.tensor2str(result[:50]):
logger(f"test_before {e}")
masked_inplace_autoregression(
logger(f"----------------------------------------------------------")
- for e, c in zip(self.tensor2str(result[:10]), self.tensor2str(correct[:10])):
+ for e, c in zip(self.tensor2str(result[:50]), self.tensor2str(correct[:50])):
logger(f"test_after {e}")
logger(f"correct {c}")