projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
world.py
diff --git
a/world.py
b/world.py
index
da7de75
..
64c7434
100755
(executable)
--- a/
world.py
+++ b/
world.py
@@
-169,7
+169,7
@@
def train_encoder(
train_loss = F.cross_entropy(output, input)
if lambda_entropy > 0:
train_loss = F.cross_entropy(output, input)
if lambda_entropy > 0:
-
loss =
loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
train_loss = train_
loss + lambda_entropy * loss_H(z, h_threshold=0.5)
acc_train_loss += train_loss.item() * input.size(0)
acc_train_loss += train_loss.item() * input.size(0)
@@
-439,26
+439,21
@@
if __name__ == "__main__":
frame2seq,
seq2frame,
) = create_data_and_processors(
frame2seq,
seq2frame,
) = create_data_and_processors(
- # 10000, 1000,
- 100,
- 100,
- nb_epochs=2,
+ 25000, 1000,
+ nb_epochs=10,
mode="first_last",
nb_steps=20,
)
mode="first_last",
nb_steps=20,
)
- input = test_input[:
64
]
+ input = test_input[:
256
]
seq = frame2seq(input)
seq = frame2seq(input)
-
- print(f"{seq.size()=} {seq.dtype=} {seq.min()=} {seq.max()=}")
-
output = seq2frame(seq)
torchvision.utils.save_image(
output = seq2frame(seq)
torchvision.utils.save_image(
- input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=
8
+ input.float() / (Box.nb_rgb_levels - 1), "orig.png", nrow=
16
)
torchvision.utils.save_image(
)
torchvision.utils.save_image(
- output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=
8
+ output.float() / (Box.nb_rgb_levels - 1), "qtiz.png", nrow=
16
)
)