From: François Fleuret Date: Tue, 11 Jul 2023 06:13:35 +0000 (+0200) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=1f6f5e352af881e57e26fa39ca5bf793c5d2c9c5;p=culture.git Update. --- diff --git a/world.py b/world.py index a43eff9..d32d545 100755 --- a/world.py +++ b/world.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -import math, sys +import math, sys, tqdm import torch, torchvision @@ -232,10 +232,11 @@ class Normalizer(nn.Module): def __init__(self, mu, std): super().__init__() self.mu = nn.Parameter(mu) - self.log_var = nn.Parameter(2*torch.log(std)) + self.log_var = nn.Parameter(2 * torch.log(std)) def forward(self, x): - return (x-self.mu)/torch.exp(self.log_var/2.0) + return (x - self.mu) / torch.exp(self.log_var / 2.0) + class SignSTE(nn.Module): def __init__(self): @@ -256,8 +257,9 @@ def train_encoder( dim_hidden=64, block_size=16, nb_bits_per_block=10, - lr_start=1e-3, lr_end=1e-5, - nb_epochs=50, + lr_start=1e-3, + lr_end=1e-5, + nb_epochs=10, batch_size=25, device=torch.device("cpu"), ): @@ -312,12 +314,19 @@ def train_encoder( model.to(device) for k in range(nb_epochs): - lr=math.exp(math.log(lr_start) + math.log(lr_end/lr_start)/(nb_epochs-1)*k) + lr = math.exp( + math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k + ) print(f"lr {lr}") optimizer = torch.optim.Adam(model.parameters(), lr=lr) acc_loss, nb_samples = 0.0, 0 - for input in train_input.split(batch_size): + for input in tqdm.tqdm( + train_input.split(batch_size), + dynamic_ncols=True, + desc="vqae-train", + total=train_input.size(0) // batch_size, + ): output = model(input) loss = F.mse_loss(output, input) acc_loss += loss.item() * input.size(0) @@ -341,7 +350,11 @@ if __name__ == "__main__": all_frames = [] nb = 25000 start_time = time.perf_counter() - for n in range(nb): + for n in tqdm.tqdm( + range(nb), + dynamic_ncols=True, + desc="world-data", + ): frames, actions = generate_sequence(nb_steps=31) all_frames += frames end_time = time.perf_counter()