+ while True:
+ a = (col.flatten() >= 0).nonzero()
+ a = a[torch.randint(a.size(0), (1,)).item()]
+ i, j = a // self.size, a % self.size
+ assert col[i, j] >= 0
+ dst = [(i, j), (i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
+ dst = list(
+ filter(
+ lambda x: x[0] >= 0
+ and x[1] >= 0
+ and x[0] < self.size
+ and x[1] < self.size
+ and col[x[0], x[1]] < 0,
+ dst,
+ )
+ )
+ if len(dst) > 0:
+ ni, nj = dst[torch.randint(len(dst), (1,)).item()]
+ col[ni, nj] = col[i, j]
+ shp[ni, nj] = shp[i, j]
+ col[i, j] = -1
+ shp[i, j] = -1
+ break
+
+ return col, shp
+
+ def transformation(self, t, scene):
+ col, shp = scene
+ if t == 0:
+ col, shp = col.flip(0), shp.flip(0)
+ description = "<chg> vertical flip"
+ elif t == 1:
+ col, shp = col.flip(1), shp.flip(1)
+ description = "<chg> horizontal flip"
+ elif t == 2:
+ col, shp = col.flip(0).t(), shp.flip(0).t()
+ description = "<chg> rotate 90 degrees"
+ elif t == 3:
+ col, shp = col.flip(0).flip(1), shp.flip(0).flip(1)
+ description = "<chg> rotate 180 degrees"
+ elif t == 4:
+ col, shp = col.flip(1).t(), shp.flip(1).t()
+ description = "<chg> rotate 270 degrees"
+
+ return (col.contiguous(), shp.contiguous()), description
+
+ def random_transformations(self, scene):