X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=world.py;h=fb8609d82109ce2f29d0612a553bfe77d36a74c5;hb=2192d72289bbf2cd069f67d3e93daf7934f886af;hp=b0779878e0264d79e368f9147bb95363e7caf122;hpb=0c495b7d79915a65d6680203086a94e06df80580;p=picoclvr.git diff --git a/world.py b/world.py index b077987..fb8609d 100755 --- a/world.py +++ b/world.py @@ -65,15 +65,19 @@ class SignSTE(nn.Module): def train_encoder( train_input, test_input, - depth=3, + depth=2, dim_hidden=48, nb_bits_per_token=8, lr_start=1e-3, lr_end=1e-4, nb_epochs=10, batch_size=25, + logger=None, device=torch.device("cpu"), ): + if logger is None: + logger = lambda s: print(s) + mu, std = train_input.float().mean(), train_input.float().std() def encoder_core(depth, dim): @@ -132,7 +136,7 @@ def train_encoder( nb_parameters = sum(p.numel() for p in model.parameters()) - print(f"nb_parameters {nb_parameters}") + logger(f"nb_parameters {nb_parameters}") model.to(device) @@ -146,7 +150,7 @@ def train_encoder( for input in tqdm.tqdm(train_input.split(batch_size), desc="vqae-train"): z = encoder(input) - zq = z if k < 1 else quantizer(z) + zq = z if k < 2 else quantizer(z) output = decoder(zq) output = output.reshape( @@ -179,7 +183,7 @@ def train_encoder( train_loss = acc_train_loss / train_input.size(0) test_loss = acc_test_loss / test_input.size(0) - print(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}") + logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}") sys.stdout.flush() return encoder, quantizer, decoder @@ -322,22 +326,35 @@ def generate_episode(steps, size=64): def generate_episodes(nb, steps): - all_frames = [] + all_frames, all_actions = [], [] for n in tqdm.tqdm(range(nb), dynamic_ncols=True, desc="world-data"): frames, actions = generate_episode(steps) all_frames += frames - return torch.cat(all_frames, 0).contiguous() + all_actions += [actions[None, :]] + return torch.cat(all_frames, 0).contiguous(), torch.cat(all_actions, 0) -def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10): - steps = [True] + [False] * 30 + [True] - train_input = generate_episodes(nb_train_samples, steps) - test_input = generate_episodes(nb_test_samples, steps) +def create_data_and_processors( + nb_train_samples, + nb_test_samples, + mode, + nb_steps, + nb_epochs=10, + device=torch.device("cpu"), + logger=None, +): + assert mode in ["first_last"] + + if mode == "first_last": + steps = [True] + [False] * (nb_steps + 1) + [True] - print(f"{train_input.size()=} {test_input.size()=}") + train_input, train_actions = generate_episodes(nb_train_samples, steps) + train_input, train_actions = train_input.to(device), train_actions.to(device) + test_input, test_actions = generate_episodes(nb_test_samples, steps) + test_input, test_actions = test_input.to(device), test_actions.to(device) encoder, quantizer, decoder = train_encoder( - train_input, test_input, nb_epochs=nb_epochs + train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device ) encoder.train(False) quantizer.train(False) @@ -347,35 +364,61 @@ def create_data_and_processors(nb_train_samples, nb_test_samples, nb_epochs=10): pow2 = (2 ** torch.arange(z.size(1), device=z.device))[None, None, :] z_h, z_w = z.size(2), z.size(3) - def frame2seq(x): - z = encoder(x) - ze_bool = (quantizer(z) >= 0).long() - seq = ( - ze_bool.permute(0, 2, 3, 1).reshape(ze_bool.size(0), -1, ze_bool.size(1)) - * pow2 - ).sum(-1) - return seq - - def seq2frame(seq, T=1e-2): - zd_bool = (seq[:, :, None] // pow2) % 2 - zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2) - logits = decoder(zd_bool * 2.0 - 1.0) - logits = logits.reshape( - logits.size(0), -1, 3, logits.size(2), logits.size(3) - ).permute(0, 2, 3, 4, 1) - results = torch.distributions.categorical.Categorical( - logits=logits / T - ).sample() - return results - - return train_input, test_input, frame2seq, seq2frame + def frame2seq(input, batch_size=25): + seq = [] + + for x in input.split(batch_size): + z = encoder(x) + ze_bool = (quantizer(z) >= 0).long() + output = ( + ze_bool.permute(0, 2, 3, 1).reshape( + ze_bool.size(0), -1, ze_bool.size(1) + ) + * pow2 + ).sum(-1) + + seq.append(output) + + return torch.cat(seq, dim=0) + + def seq2frame(input, batch_size=25, T=1e-2): + frames = [] + + for seq in input.split(batch_size): + zd_bool = (seq[:, :, None] // pow2) % 2 + zd_bool = zd_bool.reshape(zd_bool.size(0), z_h, z_w, -1).permute(0, 3, 1, 2) + logits = decoder(zd_bool * 2.0 - 1.0) + logits = logits.reshape( + logits.size(0), -1, 3, logits.size(2), logits.size(3) + ).permute(0, 2, 3, 4, 1) + output = torch.distributions.categorical.Categorical( + logits=logits / T + ).sample() + + frames.append(output) + + return torch.cat(frames, dim=0) + + return train_input, train_actions, test_input, test_actions, frame2seq, seq2frame ###################################################################### if __name__ == "__main__": - train_input, test_input, frame2seq, seq2frame = create_data_and_processors( - 10000, 1000 + ( + train_input, + train_actions, + test_input, + test_actions, + frame2seq, + seq2frame, + ) = create_data_and_processors( + # 10000, 1000, + 100, + 100, + nb_epochs=2, + mode="first_last", + nb_steps=20, ) input = test_input[:64]