projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
world.py
diff --git
a/world.py
b/world.py
index
5c21fad
..
fb8609d
100755
(executable)
--- a/
world.py
+++ b/
world.py
@@
-72,8
+72,12
@@
def train_encoder(
lr_end=1e-4,
nb_epochs=10,
batch_size=25,
lr_end=1e-4,
nb_epochs=10,
batch_size=25,
+ logger=None,
device=torch.device("cpu"),
):
device=torch.device("cpu"),
):
+ if logger is None:
+ logger = lambda s: print(s)
+
mu, std = train_input.float().mean(), train_input.float().std()
def encoder_core(depth, dim):
mu, std = train_input.float().mean(), train_input.float().std()
def encoder_core(depth, dim):
@@
-132,7
+136,7
@@
def train_encoder(
nb_parameters = sum(p.numel() for p in model.parameters())
nb_parameters = sum(p.numel() for p in model.parameters())
-
print
(f"nb_parameters {nb_parameters}")
+
logger
(f"nb_parameters {nb_parameters}")
model.to(device)
model.to(device)
@@
-179,7
+183,7
@@
def train_encoder(
train_loss = acc_train_loss / train_input.size(0)
test_loss = acc_test_loss / test_input.size(0)
train_loss = acc_train_loss / train_input.size(0)
test_loss = acc_test_loss / test_input.size(0)
-
print
(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+
logger
(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
sys.stdout.flush()
return encoder, quantizer, decoder
sys.stdout.flush()
return encoder, quantizer, decoder
@@
-326,7
+330,7
@@
def generate_episodes(nb, steps):
for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
frames, actions = generate_episode(steps)
all_frames += frames
for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
frames, actions = generate_episode(steps)
all_frames += frames
- all_actions += [actions]
+ all_actions += [actions
[None, :]
]
return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
@@
-337,6
+341,7
@@
def create_data_and_processors(
nb_steps,
nb_epochs=10,
device=torch.device("cpu"),
nb_steps,
nb_epochs=10,
device=torch.device("cpu"),
+ logger=None,
):
assert mode in ["first_last"]
):
assert mode in ["first_last"]
@@
-349,7
+354,7
@@
def create_data_and_processors(
test_input, test_actions = test_input.to(device), test_actions.to(device)
encoder, quantizer, decoder = train_encoder(
test_input, test_actions = test_input.to(device), test_actions.to(device)
encoder, quantizer, decoder = train_encoder(
- train_input, test_input, nb_epochs=nb_epochs, device=device
+ train_input, test_input, nb_epochs=nb_epochs,
logger=logger,
device=device
)
encoder.train(False)
quantizer.train(False)
)
encoder.train(False)
quantizer.train(False)