X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=world.py;h=1d64fa39cc66885cf32896df4e0ae199fa1cc2bf;hb=a2ffcd9b27aa0f3cc0b56090a32e88b73dfa0a54;hp=12c65535ff493700429167132494c5befc05be7e;hpb=5366dfd7bd57ec3298d1030f7d5327ff26bc5aad;p=picoclvr.git diff --git a/world.py b/world.py index 12c6553..1d64fa3 100755 --- a/world.py +++ b/world.py @@ -61,12 +61,13 @@ class SignSTE(nn.Module): else: return s + class DiscreteSampler2d(nn.Module): def __init__(self): super().__init__() def forward(self, x): - s = (x >= x.max(-3,keepdim=True).values).float() + s = (x >= x.max(-3, keepdim=True).values).float() if self.training: u = x.softmax(dim=-3) @@ -96,7 +97,6 @@ def train_encoder( logger=None, device=torch.device("cpu"), ): - mu, std = train_input.float().mean(), train_input.float().std() def encoder_core(depth, dim): @@ -459,7 +459,8 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - 25000, 1000, + 25000, + 1000, nb_epochs=5, mode="first_last", nb_steps=20,