X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=world.py;h=aad0bfb9727a3757dd90a0bfcb56e74040c6e011;hb=732349f7c16e43ff84380d28e021d671f2c56492;hp=12c65535ff493700429167132494c5befc05be7e;hpb=5366dfd7bd57ec3298d1030f7d5327ff26bc5aad;p=picoclvr.git diff --git a/world.py b/world.py index 12c6553..aad0bfb 100755 --- a/world.py +++ b/world.py @@ -1,5 +1,10 @@ #!/usr/bin/env python +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + import math, sys, tqdm import torch, torchvision @@ -61,12 +66,13 @@ class SignSTE(nn.Module): else: return s + class DiscreteSampler2d(nn.Module): def __init__(self): super().__init__() def forward(self, x): - s = (x >= x.max(-3,keepdim=True).values).float() + s = (x >= x.max(-3, keepdim=True).values).float() if self.training: u = x.softmax(dim=-3) @@ -96,7 +102,6 @@ def train_encoder( logger=None, device=torch.device("cpu"), ): - mu, std = train_input.float().mean(), train_input.float().std() def encoder_core(depth, dim): @@ -459,7 +464,8 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - 25000, 1000, + 25000, + 1000, nb_epochs=5, mode="first_last", nb_steps=20,