#!/usr/bin/env python
-import math, sys
+import math, sys, tqdm
import torch, torchvision
def __init__(self, mu, std):
super().__init__()
self.mu = nn.Parameter(mu)
- self.log_var = nn.Parameter(2*torch.log(std))
+ self.log_var = nn.Parameter(2 * torch.log(std))
def forward(self, x):
- return (x-self.mu)/torch.exp(self.log_var/2.0)
+ return (x - self.mu) / torch.exp(self.log_var / 2.0)
+
class SignSTE(nn.Module):
def __init__(self):
dim_hidden=64,
block_size=16,
nb_bits_per_block=10,
- lr_start=1e-3, lr_end=1e-5,
- nb_epochs=50,
+ lr_start=1e-3,
+ lr_end=1e-5,
+ nb_epochs=10,
batch_size=25,
device=torch.device("cpu"),
):
model.to(device)
for k in range(nb_epochs):
- lr=math.exp(math.log(lr_start) + math.log(lr_end/lr_start)/(nb_epochs-1)*k)
+ lr = math.exp(
+ math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k
+ )
print(f"lr {lr}")
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
acc_loss, nb_samples = 0.0, 0
- for input in train_input.split(batch_size):
+ for input in tqdm.tqdm(
+ train_input.split(batch_size),
+ dynamic_ncols=True,
+ desc="vqae-train",
+ total=train_input.size(0) // batch_size,
+ ):
output = model(input)
loss = F.mse_loss(output, input)
acc_loss += loss.item() * input.size(0)
all_frames = []
nb = 25000
start_time = time.perf_counter()
- for n in range(nb):
+ for n in tqdm.tqdm(
+ range(nb),
+ dynamic_ncols=True,
+ desc="world-data",
+ ):
frames, actions = generate_sequence(nb_steps=31)
all_frames += frames
end_time = time.perf_counter()