From 46645637edb8a39ed6a674696f9d78cc4603b805 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Sat, 6 Jul 2024 08:21:39 +0300 Subject: [PATCH] Update. --- main.py | 2 +- reasoning.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 50e5611..02e1a8d 100755 --- a/main.py +++ b/main.py @@ -126,9 +126,9 @@ if args.result_dir is None: if args.dirty_debug: args.accuracy_to_make_c_quizzes = 0.0 + args.nb_gpts = 2 nb_new_c_quizzes_for_train = 100 nb_new_c_quizzes_for_test = 10 - args.nb_gpts = 2 ###################################################################### diff --git a/reasoning.py b/reasoning.py index 374d518..951e04a 100755 --- a/reasoning.py +++ b/reasoning.py @@ -117,7 +117,7 @@ class Reasoning(problem.Problem): c = c.long()[:, None] c = ( (1 - ((c == 1).long() + (c == 0).long() + (c == -1).long())) - * torch.tensor([192, 192, 192], device=c.device) + * torch.tensor([128, 128, 128], device=c.device) + (c == 1).long() * torch.tensor([0, 255, 0], device=c.device) + (c == 0).long() * torch.tensor([255, 255, 255], device=c.device) + (c == -1).long() * torch.tensor([255, 0, 0], device=c.device) -- 2.20.1