X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=cec670461cbccd25ab12c56bf6ceeadae63e6e6d;hb=e244104e7b697b79e8500b5d648ec161c4ed9a63;hp=8fe89beff6d5947fc55eb713b570ba4d48fb1a61;hpb=45a3c70758eb867106537ff7c20491bc32ef5f1e;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 8fe89be..cec6704 100755 --- a/tasks.py +++ b/tasks.py @@ -911,7 +911,7 @@ class Expr(Task): test_nb_correct, test_nb_delta, test_nb_missed, - ) = compute_nb_correct(self.test_input[:1000]) + ) = compute_nb_correct(self.test_input[:10000]) logger( f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%" @@ -937,11 +937,12 @@ class Expr(Task): input = self.tensorize(sequences) result = input.clone() - ar_mask = (result == self.space).long().cumsum(dim=1).clamp(max=1) + s = (result == self.space).long() + ar_mask = (s.cumsum(dim=1) - s).clamp(min=0, max=1) result = (1 - ar_mask) * result + ar_mask * self.filler - # for n in range(result.size(0)): - # logger(f"test_before {self.seq2str(result[n])}") + for n in range(result.size(0)): + logger(f"test_before {self.seq2str(result[n])}") masked_inplace_autoregression( model, @@ -956,7 +957,7 @@ class Expr(Task): for n in range(result.size(0)): comment = "GOOD" if (result[n] - input[n]).abs().max() == 0 else "" logger(f"test_after {self.seq2str(result[n])} {comment}") - logger(f"correct {self.seq2str(correct[n])}") + logger(f"truth {self.seq2str(correct[n])}") ############################################################## model.train(t)