projects
/
pytorch.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[pytorch.git]
/
tiny_vae.py
diff --git
a/tiny_vae.py
b/tiny_vae.py
index
577f717
..
784f775
100755
(executable)
--- a/
tiny_vae.py
+++ b/
tiny_vae.py
@@
-24,10
+24,14
@@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
######################################################################
######################################################################
-parser = argparse.ArgumentParser(description="Tiny LeNet-like auto-encoder.")
+parser = argparse.ArgumentParser(
+ description="Very simple implementation of a VAE for teaching."
+)
parser.add_argument("--nb_epochs", type=int, default=100)
parser.add_argument("--nb_epochs", type=int, default=100)
+parser.add_argument("--learning_rate", type=float, default=2e-4)
+
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--data_dir", type=str, default="./data/")
parser.add_argument("--batch_size", type=int, default=100)
parser.add_argument("--data_dir", type=str, default="./data/")
@@
-135,6
+139,7
@@
class ImageGivenLatentNet(nn.Module):
def forward(self, z):
output = self.model(z.view(z.size(0), -1, 1, 1))
mu, log_var = output[:, 0:1], output[:, 1:2]
def forward(self, z):
output = self.model(z.view(z.size(0), -1, 1, 1))
mu, log_var = output[:, 0:1], output[:, 1:2]
+ # log_var.flatten(1)[...]=log_var.flatten(1)[:,:1]
return mu, log_var
return mu, log_var
@@
-160,7
+165,7
@@
model_p_X_given_z = ImageGivenLatentNet(
optimizer = optim.Adam(
itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
optimizer = optim.Adam(
itertools.chain(model_p_X_given_z.parameters(), model_q_Z_given_x.parameters()),
- lr=
4e-4
,
+ lr=
args.learning_rate
,
)
model_p_X_given_z.to(device)
)
model_p_X_given_z.to(device)