if args.dirty_debug:
args.accuracy_to_make_c_quizzes = 0.0
+ args.nb_gpts = 2
nb_new_c_quizzes_for_train = 100
nb_new_c_quizzes_for_test = 10
- args.nb_gpts = 2
######################################################################
c = c.long()[:, None]
c = (
(1 - ((c == 1).long() + (c == 0).long() + (c == -1).long()))
- * torch.tensor([192, 192, 192], device=c.device)
+ * torch.tensor([128, 128, 128], device=c.device)
+ (c == 1).long() * torch.tensor([0, 255, 0], device=c.device)
+ (c == 0).long() * torch.tensor([255, 255, 255], device=c.device)
+ (c == -1).long() * torch.tensor([255, 0, 0], device=c.device)