X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=b8fcdb34b32a07e8fd46900c88d889c3cf9761ea;hb=76e62a5782fc2509ce989fcfc0d0aedc17322b3a;hp=40598568806400b80abe9af5ba64d8d92110c07e;hpb=8cb67a3cf972dbba5741b5b48d531c1e84439745;p=picoclvr.git diff --git a/problems.py b/problems.py index 4059856..b8fcdb3 100755 --- a/problems.py +++ b/problems.py @@ -313,8 +313,8 @@ class ProblemMixing(Problem): return y def start_error(self, x): - i = torch.arange(self.height).reshape(1,-1,1).expand_as(x) - j = torch.arange(self.width).reshape(1,1,-1).expand_as(x) + i = torch.arange(self.height, device=x.device).reshape(1,-1,1).expand_as(x) + j = torch.arange(self.width, device=x.device).reshape(1,1,-1).expand_as(x) ri = (x == self.height * self.width).long().sum(dim=-1).argmax(-1).view(-1,1,1) rj = (x == self.height * self.width).long().sum(dim=-2).argmax(-1).view(-1,1,1)