lr_end=1e-4,
nb_epochs=10,
batch_size=25,
+ logger=None,
device=torch.device("cpu"),
):
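+ # if no logger is supplied, fall back to printing to stdout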
+ if logger is None:
+ logger = print
+
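+ # normalization statistics, computed once over the training frames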
mu, std = train_input.float().mean(), train_input.float().std()
def encoder_core(depth, dim):
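+ # count and report the total number of model parameters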
nb_parameters = sum(p.numel() for p in model.parameters())
- print(f"nb_parameters {nb_parameters}")
+ logger(f"nb_parameters {nb_parameters}")
model.to(device)
train_loss = acc_train_loss / train_input.size(0)
test_loss = acc_test_loss / test_input.size(0)
- print(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+ logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
sys.stdout.flush()
return encoder, quantizer, decoder
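+ # roll out nb episodes and accumulate their frames and actions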
for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"):
frames, actions = generate_episode(steps)
all_frames += frames
- all_actions += [actions]
+ all_actions += [actions[None, :]]
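+ # [None, :] adds a leading episode dimension so the torch.cat below can stack along dim 0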
return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0)
nb_steps,
nb_epochs=10,
device=torch.device("cpu"),
+ logger=None,
):
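+ # the only supported mode is "first_last"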
assert mode in ["first_last"]
test_input, test_actions = test_input.to(device), test_actions.to(device)
encoder, quantizer, decoder = train_encoder(
- train_input, test_input, nb_epochs=nb_epochs, device=device
+ train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device
)
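+ # put the trained encoder and quantizer in eval mode (freezes dropout / batch-norm updates)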
encoder.train(False)
quantizer.train(False)