From e3a8032a070175ece08fc79c77312d5f2f59150e Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Mon, 17 Jul 2023 14:25:45 +0200 Subject: [PATCH] Update. --- world.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/world.py b/world.py index da7de75..64c7434 100755 --- a/world.py +++ b/world.py @@ -169,7 +169,7 @@ def train_encoder( train_loss = F.cross_entropy(output, input) if lambda_entropy > 0: - loss = loss + lambda_entropy * loss_H(z, h_threshold=0.5) + train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5) acc_train_loss += train_loss.item() * input.size(0) @@ -439,26 +439,21 @@ if __name__ == "__main__": frame2seq, seq2frame, ) = create_data_and_processors( - # 10000, 1000, - 100, - 100, - nb_epochs=2, + 25000, 1000, + nb_epochs=10, mode="first_last", nb_steps=20, ) - input = test_input[:64] + input = test_input[:256] seq = frame2seq(input) - - print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}") - output = seq2frame(seq) torchvision.utils.save_image( - input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=8 + input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=16 ) torchvision.utils.save_image( - output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=8 + output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=16 ) -- 2.39.5