- nb = 25000
- start_time = time.perf_counter()
- for n in range(nb):
- frames, actions = generate_sequence(nb_steps=31)
- all_frames += frames
- end_time = time.perf_counter()
- print(f"{nb / (end_time - start_time):.02f} samples per second")
+ for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
+ frames, actions = generate_episode(nb_steps=31)
+ all_frames += [ frames[0], frames[-1] ]
+ return torch.cat(all_frames, 0).contiguous()
+
+def create_data_and_processors(nb_train_samples, nb_test_samples):
+ train_input = generate_episodes(nb_train_samples)
+ test_input = generate_episodes(nb_test_samples)
+ encoder, quantizer, decoder = train_encoder(train_input, test_input, nb_epochs=2)
+
+ input = test_input[:64]
+
+ z = encoder(input.float())
+ height, width = z.size(2), z.size(3)
+ zq = quantizer(z).long()
+ pow2=(2**torch.arange(zq.size(1), device=zq.device))[None,None,:]
+ seq = (zq.permute(0,2,3,1).clamp(min=0).reshape(zq.size(0),-1,zq.size(1)) * pow2).sum(-1)
+ print(f"{seq.size()=}")
+
+ ZZ=zq