X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=tasks.py;h=f6d34a8d30f5439362521d374775f375c7c2b3ca;hb=f87a57354a1e575181e760fdaedbb2c2d5cf9fa0;hp=50d541b11ccab4fb1dca97815412e1e88d69fc93;hpb=9047bd8185ed99c1302d8812551af3d5bd4602cb;p=culture.git diff --git a/tasks.py b/tasks.py index 50d541b..f6d34a8 100755 --- a/tasks.py +++ b/tasks.py @@ -14,9 +14,6 @@ from torch.nn import functional as F from mygpt import BracketedSequence -# from graph import save_attention_image -save_attention_image = None - ###################################################################### @@ -220,7 +217,7 @@ class World(Task): self.save_image( result[:96], result_dir, - f"world_result_{n_epoch:04d}_{model.id:02d}.png", + f"world_prediction_{n_epoch:04d}_{model.id:02d}.png", logger, ) @@ -252,7 +249,7 @@ class World(Task): new_quizzes, ar_mask, deterministic_synthesis=False, - progress_bar_desc="new quizzes", + progress_bar_desc="creating quizzes", device=self.device, ) @@ -290,17 +287,12 @@ class World(Task): inverted_result, ar_mask, deterministic_synthesis=True, - progress_bar_desc="solving reverse quizzes", + progress_bar_desc="solving reversed quizzes", device=self.device, ) - nb_correct += ( - ( - (new_quizzes == result).long() - * (inverted_quizzes, inverted_result).long() - ) - .min(dim=-1) - .values - ) + nb_correct += (new_quizzes == result).long().min(dim=-1).values * ( + inverted_quizzes == inverted_result + ).long().min(dim=-1).values return new_quizzes, nb_correct