X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=world.py;h=64c7434129c15eb1bd630e67c33b330c7bb26b9b;hb=e3a8032a070175ece08fc79c77312d5f2f59150e;hp=da7de75bd143e95244812b6666179ff915bd5d1e;hpb=a92a5ca00f4277f7a133fa6cfaada2bc1981f524;p=picoclvr.git diff --git a/world.py b/world.py index da7de75..64c7434 100755 --- a/world.py +++ b/world.py @@ -169,7 +169,7 @@ def train_encoder( train_loss = F.cross_entropy(output, input) if lambda_entropy > 0: - loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5) + train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5) acc_train_loss += train_loss.item() * input.size(0) @@ -439,26 +439,21 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - # 10000, 1000, - 100, - 100, - nb_epochs=2, + 25000, 1000, + nb_epochs=10, mode="first_last", nb_steps=20, ) - input = test_input[:64] + input = test_input[:256] seq = frame2seq(input) - - print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}") - output = seq2frame(seq) torchvision.utils.save_image( - input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8 + input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16 ) torchvision.utils.save_image( - output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8 + output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16 )