+######################################################################
+
+
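+# Standardizes its input with statistics fixed at construction time. The mean
+# and log-variance are stored as buffers, so they follow the module across
+# devices without being trained.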
+class Normalizer(nn.Module):
+    def __init__(self, mu, std):
+        super().__init__()
+        self.register_buffer("mu", mu)
+        self.register_buffer("log_var", 2 * torch.log(std))
+
+    def forward(self, x):
+        return (x - self.mu) / torch.exp(self.log_var / 2.0)
+
+
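+# Binarization with a straight-through estimator: the forward pass outputs
+# hard values in {-1, +1}, while during training the backward pass uses the
+# gradient of tanh as a smooth surrogate.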
+class SignSTE(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # torch.sign() can return three values (-1, 0, +1); here we want a
+        # strict binarization to {-1, +1}
+        s = (x >= 0).float() * 2 - 1
+
+        if self.training:
+            u = torch.tanh(x)
+            return s + u - u.detach()
+        else:
+            return s
+
+
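+# Discretization over the channel dimension (dim -3): the forward pass keeps
+# a hard indicator of the channel-wise maximum, while during training the
+# backward pass uses the gradient of a softmax over the channels.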
+class DiscreteSampler2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        s = (x >= x.max(-3, keepdim=True).values).float()
+
+        if self.training:
+            u = x.softmax(dim=-3)
+            return s + u - u.detach()
+        else:
+            return s
+
+
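+# Entropy regularizer for the binary code: p is the per-bit activation rate
+# over the batch and h its entropy in bits. Bits whose entropy already exceeds
+# h_threshold contribute nothing; the loss decreases as the remaining bits get
+# closer to being active half of the time.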
+def loss_H(binary_logits, h_threshold=1):
+    p = binary_logits.sigmoid().mean(0)
+    h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
+    h.clamp_(max=h_threshold)
+    return h_threshold - h.mean()
+
+
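+# Trains a convolutional auto-encoder whose bottleneck is a binary code of
+# nb_bits_per_token bits per spatial location, and returns the encoder, the
+# quantizer, and the decoder as separate modules.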
+def train_encoder(
+    train_input,
+    test_input,
+    depth,
+    nb_bits_per_token,
+    dim_hidden=48,
+    lambda_entropy=0.0,
+    lr_start=1e-3,
+    lr_end=1e-4,
+    nb_epochs=10,
+    batch_size=25,
+    logger=None,
+    device=torch.device("cpu"),
+):
+    # fall back to print so the default logger=None does not crash the
+    # logger(...) calls below
+    if logger is None:
+        logger = print
+
+    mu, std = train_input.float().mean(), train_input.float().std()
+
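+    # Each stage of the core keeps the resolution with a 5x5 convolution and
+    # then halves it with a stride-2 2x2 convolution while doubling the
+    # number of channels.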
+    def encoder_core(depth, dim):
+        l = [
+            [
+                nn.Conv2d(
+                    dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+                ),
+                nn.ReLU(),
+                nn.Conv2d(dim * 2**k, dim * 2 ** (k + 1), kernel_size=2, stride=2),
+                nn.ReLU(),
+            ]
+            for k in range(depth)
+        ]
+
+        return nn.Sequential(*[x for m in l for x in m])
+
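+    # Mirror of encoder_core: each stage doubles the resolution back with a
+    # stride-2 transposed convolution while halving the number of channels.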
+    def decoder_core(depth, dim):
+        l = [
+            [
+                nn.ConvTranspose2d(
+                    dim * 2 ** (k + 1), dim * 2**k, kernel_size=2, stride=2
+                ),
+                nn.ReLU(),
+                nn.ConvTranspose2d(
+                    dim * 2**k, dim * 2**k, kernel_size=5, stride=1, padding=2
+                ),
+                nn.ReLU(),
+            ]
+            for k in range(depth - 1, -1, -1)
+        ]
+
+        return nn.Sequential(*[x for m in l for x in m])
+
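+    # The encoder maps a 3-channel image to nb_bits_per_token maps whose
+    # resolution is divided by 2**depth (the "64x64" / "8x8" comments below
+    # correspond to depth=3 on 64x64 inputs).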
+    encoder = nn.Sequential(
+        Normalizer(mu, std),
+        nn.Conv2d(3, dim_hidden, kernel_size=1, stride=1),
+        nn.ReLU(),
+        # 64x64
+        encoder_core(depth=depth, dim=dim_hidden),
+        # 8x8
+        nn.Conv2d(dim_hidden * 2**depth, nb_bits_per_token, kernel_size=1, stride=1),
+    )
+
+    quantizer = SignSTE()
+
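+    # The decoder maps the binary code back to 3 * Box.nb_rgb_levels maps at
+    # full resolution, i.e. one logit per rgb level for each of the three
+    # color channels.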
+    decoder = nn.Sequential(
+        nn.Conv2d(nb_bits_per_token, dim_hidden * 2**depth, kernel_size=1, stride=1),
+        # 8x8
+        decoder_core(depth=depth, dim=dim_hidden),
+        # 64x64
+        nn.ConvTranspose2d(dim_hidden, 3 * Box.nb_rgb_levels, kernel_size=1, stride=1),
+    )
+
+    model = nn.Sequential(encoder, decoder)
+
+    nb_parameters = sum(p.numel() for p in model.parameters())
+
+    logger(f"vqae nb_parameters {nb_parameters}")
+
+    model.to(device)
+
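+    # The learning rate is interpolated log-linearly from lr_start down to
+    # lr_end over the epochs, with a fresh Adam optimizer for each epoch.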
+    for k in range(nb_epochs):
+        lr = math.exp(
+            math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
+        )
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+        acc_train_loss = 0.0
+
+        for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"):
+            input = input.to(device)
+            z = encoder(input)
+            zq = quantizer(z)
+            output = decoder(zq)
+
+            # split the decoder channels into one logit per rgb level for
+            # each of the 3 color channels, as expected by F.cross_entropy
+            output = output.reshape(
+                output.size(0), -1, 3, output.size(2), output.size(3)
+            )
+
+            train_loss = F.cross_entropy(output, input)
+
+            if lambda_entropy > 0:
+                train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
+            acc_train_loss += train_loss.item() * input.size(0)
+
+            optimizer.zero_grad()
+            train_loss.backward()
+            optimizer.step()
+
+        acc_test_loss = 0.0
+
+        for input in tqdm.tqdm(test_input.split(batch_size), desc="vqae-test"):
+            input = input.to(device)
+            z = encoder(input)
+            zq = quantizer(z)
+            output = decoder(zq)
+
+            output = output.reshape(
+                output.size(0), -1, 3, output.size(2), output.size(3)
+            )
+
+            test_loss = F.cross_entropy(output, input)
+
+            acc_test_loss += test_loss.item() * input.size(0)
+
+        train_loss = acc_train_loss / train_input.size(0)
+        test_loss = acc_test_loss / test_input.size(0)
+
+        logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+        sys.stdout.flush()
+
+    return encoder, quantizer, decoder
+
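+# Example use (a sketch: `train_frames` / `test_frames` are assumed to be
+# tensors of shape (N, 3, 64, 64) holding integer rgb levels in
+# [0, Box.nb_rgb_levels), and depth=3 / nb_bits_per_token=16 are illustrative
+# values):
+#
+#   encoder, quantizer, decoder = train_encoder(
+#       train_frames,
+#       test_frames,
+#       depth=3,
+#       nb_bits_per_token=16,
+#       logger=print,
+#       device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+#   )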
+
+######################################################################
+
+