X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=af71b85ed7de9d0639b5fc4e95351693608be030;hb=b59fca62aa31de18a3e0cd0bb54e395d4b1254ae;hp=ca71182345e40b1123a9a32703a4d34b10066280;hpb=27dd45177a8552a8b8c9e3a4d9388844fa1d4d27;p=picoclvr.git diff --git a/tasks.py b/tasks.py index ca71182..af71b85 100755 --- a/tasks.py +++ b/tasks.py @@ -1111,6 +1111,7 @@ class RPL(Task): self.test_input = self.tensorize(test_sequences) if no_prog: + # Excise the program from every train and test example k = torch.arange(self.train_input.size(1), device=self.train_input.device)[ None, : ] @@ -1185,13 +1186,13 @@ class RPL(Task): ) sum_nb_total, sum_nb_errors = 0, 0 - for x, y in zip(input, result): - seq = [self.id2token[i.item()] for i in y] + for one_input, one_result in zip(input, result): + seq = [self.id2token[i.item()] for i in one_result] nb_total, nb_errors, prog, stacks = rpl.compute_nb_errors(seq) sum_nb_total += 1 sum_nb_errors += 0 if nb_errors == 0 else 1 if nb_to_log > 0: - gt_seq = [self.id2token[i.item()] for i in x] + gt_seq = [self.id2token[i.item()] for i in one_input] _, _, gt_prog, _ = rpl.compute_nb_errors(gt_seq) gt_prog = " ".join([str(x) for x in gt_prog]) prog = " ".join([str(x) for x in prog]) @@ -1232,14 +1233,20 @@ class RPL(Task): ) sum_nb_total, sum_nb_errors = 0, 0 - for x, y, i, j in zip(input, result, last_output_idx, first_prog_idx): - seq = [self.id2token[i.item()] for i in y] + for one_input, one_result, i, j in zip( + input, result, last_output_idx, first_prog_idx + ): + seq = [self.id2token[i.item()] for i in one_result] sum_nb_total += 1 - correct = (x - y).abs().max() == 0 + correct = (one_input - one_result).abs().max() == 0 sum_nb_errors += 0 if correct else 1 if nb_to_log > 0: - result_stack = [self.id2token[i.item()] for i in y[i : j + 1]] - target_stack = [self.id2token[i.item()] for i in x[i : j + 1]] + result_stack = [ + self.id2token[i.item()] for i in one_result[i : j + 1] + ] + target_stack = [ + self.id2token[i.item()] for i in one_input[i : j + 1] + ] comment = "*" if correct else "-" result_stack = " ".join([str(x) for x in result_stack]) target_stack = " ".join([str(x) for x in target_stack])