- mu, std = train_input.float().mean(), train_input.float().std()
-
- def encoder_core(depth, dim):
- l = [
- [
- nn.Conv2d(
- dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
- ),
- nn.ReLU(),
- nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
- nn.ReLU(),
- ]
- for k in range(depth)
- ]
-
- return nn.Sequential(*[x for m in l for x in m])
-
- def decoder_core(depth, dim):
- l = [
- [
- nn.ConvTranspose2d(
- dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
- ),
- nn.ReLU(),
- nn.ConvTranspose2d(
- dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
- ),
- nn.ReLU(),
- ]
- for k in range(depth - 1, -1, -1)
- ]
-
- return nn.Sequential(*[x for m in l for x in m])
-
- encoder = nn.Sequential(
- Normalizer(mu, std),
- nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
- nn.ReLU(),
- # 64x64
- encoder_core(depth=depth, dim=dim_hidden),
- # 8x8
- nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
- )
-
- quantizer = SignSTE()
-
- decoder = nn.Sequential(
- nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
- # 8x8
- decoder_core(depth=depth, dim=dim_hidden),
- # 64x64
- nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
- )
-
- model = nn.Sequential(encoder, decoder)
-
- nb_parameters = sum(p.numel() for p in model.parameters())
-
- print(f"nb_parameters {nb_parameters}")
-
- model.to(device)
-
- for k in range(nb_epochs):
- lr = math.exp(
- math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
- )
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-
- acc_train_loss = 0.0
-
- for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
- z = encoder(input)
- zq = z if k < 1 else quantizer(z)
- output = decoder(zq)
-
- output = output.reshape(
- output.size(0), -1, 3, output.size(2), output.size(3)
+ f_start = torch.zeros(nb, height, width, dtype=torch.int64)
+ f_end = torch.zeros(nb, height, width, dtype=torch.int64)
+ n = torch.arange(f_start.size(0))
+
+ for n in range(nb):
+ nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
+ for c in range(nb_fish):
+ i, j = (
+ torch.randint(height - 2, (1,))[0] + 1,
+ torch.randint(width - 2, (1,))[0] + 1,