From: François Fleuret Date: Wed, 17 Jul 2024 03:24:56 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=76436f85416f2c4f5cedd55b7fbe2a3864d04921;p=culture.git Update. --- diff --git a/main.py b/main.py index 76db5e2..74a3cfb 100755 --- a/main.py +++ b/main.py @@ -367,8 +367,8 @@ def one_epoch(model, quiz_machine, local_device=main_device): acc_train_loss += loss.item() * input.size(0) loss_per_samples = loss_per_token.detach().flatten(1).mean(dim=1) - n_forward = input[:, 0] == self.token_forward - to_store = from_w & n_forward + n_forward = input[:, 0] == quiz_machine.token_forward + to_store = from_w & n_forward.to("cpu") if to_store.any(): hard_w_quizzes.append( (input[to_store].to("cpu"), loss_per_samples[to_store].to("cpu"))