+ # Accumulate the test loss weighted by batch size, so the division below
+ # yields a per-sample average. NOTE(review): this presumes test_loss is
+ # mean-reduced over the batch — confirm against the loss definition above.
+ acc_test_loss += test_loss.item() * input.size(0)
+
+ # Per-sample average losses over the full train / test sets.
+ # NOTE(review): acc_train_loss, k and lr are defined above this chunk.
+ train_loss = acc_train_loss / train_input.size(0)
+ test_loss = acc_test_loss / test_input.size(0)
+
+ print(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+ # Flush so progress lines appear promptly when stdout is redirected to a log.
+ sys.stdout.flush()
+
+ # Return the trained auto-encoder components.
+ return encoder, quantizer, decoder
+
def generate_episodes(nb):
    """Generate ``nb`` episodes and keep the first and last frame of each.

    Returns a single contiguous tensor with the 2*nb retained frames
    concatenated along dimension 0.
    """
    collected = []
    # tqdm progress bar labelled "world-data"; one step per episode.
    for _ in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
        frames, _actions = generate_episode(nb_steps=31)
        # Only the endpoints of the episode are kept.
        collected.append(frames[0])
        collected.append(frames[-1])
    return torch.cat(collected, 0).contiguous()
+
+def create_data_and_processors(nb_train_samples, nb_test_samples):
+ # Generate train/test frame sets, train the auto-encoder on them, then
+ # sanity-check that quantized codes pack into integer token sequences and
+ # unpack back without loss, saving original vs. reconstructed images.
+ train_input = generate_episodes(nb_train_samples)
+ test_input = generate_episodes(nb_test_samples)
+ encoder, quantizer, decoder = train_encoder(train_input, test_input, nb_epochs=2)
+
+ # Small batch of test frames for the round-trip check below.
+ # NOTE(review): `input` shadows the builtin of the same name.
+ input = test_input[:64]
+
+ z = encoder(input.float())
+ height, width = z.size(2), z.size(3)
+ # Quantized code as integers. The clamp(min=0) below maps it to {0,1} bits
+ # and the unpacking uses *2-1, so the quantizer presumably emits values in
+ # {-1,+1} — TODO confirm against the quantizer definition.
+ zq = quantizer(z).long()
+ # Per-channel binary weights 1,2,4,... used to pack the bit-planes;
+ # broadcast shape (1,1,C).
+ pow2=(2**torch.arange(zq.size(1), device=zq.device))[None,None,:]
+ # Pack: (N,C,H,W) bits -> (N,H*W) integer tokens, one token per position.
+ seq = (zq.permute(0,2,3,1).clamp(min=0).reshape(zq.size(0),-1,zq.size(1)) * pow2).sum(-1)
+ print(f"{seq.size()=}")
+
+ # Keep the original quantized tensor to verify the round trip.
+ ZZ=zq
+
+ # Unpack: recover bit-planes from the tokens, map {0,1} back to {-1,+1},
+ # and restore the (N,C,H,W) layout.
+ zq = ((seq[:,:,None] // pow2)%2)*2-1
+ zq = zq.reshape(zq.size(0), height, width, -1).permute(0,3,1,2)
+
+ print(ZZ[0])
+ print(zq[0])
+
+ # Should print 0 if pack/unpack round-trips exactly.
+ print("CHECK", (ZZ-zq).abs().sum())
+
+ results = decoder(zq.float())
+ # Low sampling temperature: sharpen the decoder logits before sampling.
+ T = 0.1
+ # Reshape decoder output to (N, 3, H, W, L) and sample one level per
+ # channel/pixel. Presumes the decoder's channel dim factors as L*3
+ # (levels x RGB) — TODO confirm against the decoder definition.
+ results = results.reshape(
+ results.size(0), -1, 3, results.size(2), results.size(3)
+ ).permute(0, 2, 3, 4, 1)
+ results = torch.distributions.categorical.Categorical(logits=results / T).sample()
+
+
+ # Save originals and their quantize/decode reconstructions, normalized to
+ # [0,1] by the number of RGB levels.
+ torchvision.utils.save_image(
+ input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8
+ )
+
+ torchvision.utils.save_image(
+ results.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8
+ )