class Problem:
- def generate(nb):
+ def generate_sequences(self, nb):
pass
- def perf(seq, logger):
+ def log_performance(self, sequences, logger):
pass
def __init__(self):
nb_seq, len_prompt, len_result = 100, 5, 5
self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result))
- self.seq[:, len_prompt] = -1
+ self.seq[:, len_prompt] = 10  # separator; data tokens are 0-9, and unlike -1, 10 is a valid token id
def generate_sequences(self, nb):
- return self.seq[torch.randint(self.seq.size(0), (nb,))]
-
+ sequences = self.seq[torch.randint(self.seq.size(0), (nb,))]
+ # ar_mask is 1 on the positions to generate: cumsum is 1 from the separator
+ # onward, and subtracting ar_mask zeroes the separator position itself
+ ar_mask = (sequences == 10).long()
+ ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
+ return sequences, ar_mask
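+ # Worked example (illustrative only, with shortened lengths len_prompt=2,
+ # len_result=2, not the values used above):
+ # sequences[i] = [ 3, 7, 10, 4, 1 ]
+ # ar_mask[i]   = [ 0, 0,  0, 1, 1 ]
+ # i.e. only the tokens strictly after the separator are generated autoregressively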
+
+ # problems = [ProblemByheart()]
+ # nb_common_codes = 100
+
+ # def generate_sequences(nb_samples):
+ # problem_indexes = torch.randint(len(problems), (nb_samples,))
+ # nb_samples_per_problem = torch.nn.functional.one_hot(problem_indexes, num_classes=len(problems)).sum(0)
+ # print(f"{nb_samples_per_problem}")
+ # all_seq = []
+ # for nb, p in zip(nb_samples_per_problem, problems):
+ # all_seq.append(p.generate_sequences(nb.item()))
+ # return all_seq
+
+ # for strain, stest in zip(train_seq, test_seq):
+ # s = torch.cat((strain, stest), 0)
class SandBox(Task):
def __init__(
self,
+ problem,
nb_train_samples,
nb_test_samples,
batch_size,
super().__init__()
self.batch_size = batch_size
+ self.device = device
- problems = [ProblemByheart()]
- nb_common_codes = 100
-
- def generate_sequences(nb_samples):
- problem_indexes = torch.randint(len(problems), (nb_samples,))
- nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0)
- print(f"{nb_samples_per_problem}")
- all_seq = []
- for nb, p in zip(nb_samples_per_problem, problems):
- all_seq.append(p.generate_sequences(nb_samples_per_problem[nb]))
- return all_seq
-
- train_seq = generate_sequences(nb_train_samples)
- test_seq = generate_sequences(nb_test_samples)
-
- for strain, stest in zip(train_seq, test_seq):
- s = torch.cat((strain, stest), 0)
+ self.train_input, self.train_ar_mask = problem.generate_sequences(nb_train_samples)
+ self.test_input, self.test_ar_mask = problem.generate_sequences(nb_test_samples)
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
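+ # nb_codes is the vocabulary size the model needs: largest token id seen, plus one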
def produce_results(
self, n_epoch, model, result_dir, logger, deterministic_synthesis
):
- # logger(
- # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
- # )
- pass
+ def compute_accuracy(input, ar_mask):
+ # erase the tokens to be generated, keep the prompt part
+ result = input.clone() * (1 - ar_mask)
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ progress_bar_desc=None,
+ device=self.device,
+ )
+
+ # count correct predictions only on the positions the model had to generate
+ nb_total = ar_mask.sum().item()
+ nb_correct = ((result == input).long() * ar_mask).sum().item()
+
+ return nb_total, nb_correct
+
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_input, self.train_ar_mask)
+
+ logger(
+ f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
+ )
+
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_input, self.test_ar_mask)
+
+ logger(
+ f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ )
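+ # Usage sketch (hypothetical values, wiring not part of this change):
+ # task = SandBox(ProblemByheart(), nb_train_samples=10000,
+ # nb_test_samples=1000, batch_size=25, device=torch.device("cpu"))
+ # task.produce_results(n_epoch, model, result_dir, logger, deterministic_synthesis)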
######################################################################
logger=None,
device=torch.device("cpu"),
):
- if logger is None:
- logger = lambda s: print(s)
mu, std = train_input.float().mean(), train_input.float().std()
nb_parameters = sum(p.numel() for p in model.parameters())
- logger(f"nb_parameters {nb_parameters}")
+ logger(f"vqae nb_parameters {nb_parameters}")
model.to(device)
train_loss = acc_train_loss / train_input.size(0)
test_loss = acc_test_loss / test_input.size(0)
- logger(f"train_ae {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
+ logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
sys.stdout.flush()
return encoder, quantizer, decoder
if mode == "first_last":
steps = [True] + [False] * (nb_steps + 1) + [True]
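+ # e.g. nb_steps=2 -> [True, False, False, False, True]; presumably only the
+ # first and last frames of an episode are kept in "first_last" mode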
+ if logger is None:
+ logger = lambda s: print(s)
+
train_input, train_actions = generate_episodes(nb_train_samples, steps)
train_input, train_actions = train_input.to(device_storage), train_actions.to(
device_storage
pow2 = (2 ** torch.arange(z.size(1), device=device))[None, None, :]
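+ # pow2[0, 0, k] = 2**k; apparently used by frame2seq below to pack the
+ # z.size(1) binary channels into one integer token per spatial position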
z_h, z_w = z.size(2), z.size(3)
+ logger(f"vqae input {train_input[0].size()} output {z[0].size()}")
+
def frame2seq(input, batch_size=25):
seq = []
p = pow2.to(device)