# Switches for the "mixing" task. Both are plain on/off flags that stay False
# unless they are given on the command line.
parser.add_argument(
    "--mixing_hard",
    action="store_true",
    default=False,
)
# Stored as args.mixing_deterministic_start; the "mixing" task setup negates it
# to obtain the random_start argument of ProblemMixing.
parser.add_argument(
    "--mixing_deterministic_start",
    action="store_true",
    default=False,
)

######################################################################

args = parser.parse_args()
elif args.task == "mixing":
task = tasks.SandBox(
- problem=problems.ProblemMixing(hard=args.mixing_hard),
+ problem=problems.ProblemMixing(
+ hard=args.mixing_hard, random_start=not args.mixing_deterministic_start
+ ),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
class ProblemMixing(Problem):
- def __init__(self, height=4, width=4, nb_time_steps=9, hard=False):
+ def __init__(
+ self, height=4, width=4, nb_time_steps=9, hard=False, random_start=True
+ ):
self.height = height
self.width = width
self.nb_time_steps = nb_time_steps
self.hard = hard
+ self.random_start = random_start
def start_random(self, nb):
    """Build `nb` start grids of shape (nb, height, width).

    Every grid starts as the row-major cell numbering
    0 .. height * width - 1. When `self.random_start` is true, one row and
    one column are drawn uniformly at random for each sample, and every
    cell on that row or column is replaced by the value height * width.
    When it is false, the plain numbering is returned for every sample.

    The returned tensor always owns its memory: each sample can be
    modified in place without affecting the others.

    Args:
        nb: number of grids to generate.

    Returns:
        An integer tensor of shape (nb, height, width).
    """
    nb_cells = self.height * self.width

    # repeat() materialises one independent row per sample. expand() would
    # return a stride-0 view in which all samples alias the same storage,
    # which is unsafe to hand back on the deterministic path.
    y = torch.arange(nb_cells).repeat(nb, 1)

    if self.random_start:
        # Row / column index grids; broadcasting against the per-sample
        # draws below yields the full (nb, height, width) comparison.
        i = torch.arange(self.height).reshape(1, -1, 1)
        j = torch.arange(self.width).reshape(1, 1, -1)

        # One row index and one column index per sample (row drawn first,
        # so seeded runs consume the generator in the same order as before).
        ri = torch.randint(self.height, (nb,)).reshape(nb, 1, 1)
        rj = torch.randint(self.width, (nb,)).reshape(nb, 1, 1)

        # m is 0 on the selected row and column, 1 everywhere else.
        m = 1 - torch.logical_or(i == ri, j == rj).long().flatten(1)

        # Keep the cell number where m == 1, write nb_cells where m == 0.
        y = y * m + nb_cells * (1 - m)

    return y.reshape(nb, self.height, self.width)