nb_mlps_per_batch=1024,
):
- inputs, q_test_sets = [],[]
+ seqs, q_test_sets = [],[]
for n in range(0,nb_mlps,nb_mlps_per_batch):
q_train_set, q_test_set, q_params = generate_sets_and_params(
device=device,
)
- inputs.append(torch.cat(
+ seqs.append(torch.cat(
[
q_train_set,
q_train_set.new_full(
q_test_sets.append(q_test_set)
- input = torch.cat(inputs)
+ seq = torch.cat(seqs)
q_test_set = torch.cat(q_test_sets)
- return input, q_test_set
+ return seq, q_test_set
######################################################################
# NOTE(review): this section reads as an unapplied unified-diff fragment:
# "-" lines are the pre-change text and "+" lines the replacement — a
# rename of the local variable `input` to `seq`, apparently to stop
# shadowing the Python builtin `input`. Both variants are preserved
# byte-identical below; only the "+" lines reflect the intended code.
#
# Driver: generate one batched sequence plus its held-out test set,
# report generation throughput, then evaluate the recovered MLP
# parameters on both the train and test splits.
data = []
# Build the concatenated per-batch sequences and matching test set.
- input, q_test_set = generate_sequence_and_test_set(
+ seq, q_test_set = generate_sequence_and_test_set(
nb_mlps=batch_nb_mlps,
nb_samples=nb_samples,
device=device,
)
# Throughput of the generation step; `start_time` is assigned before
# this chunk (not visible here) — presumably time.perf_counter().
end_time = time.perf_counter()
- print(f"{input.size(0) / (end_time - start_time):.02f} samples per second")
+ print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")
# Split the sequence back into its train-set prefix (first
# nb_samples * 3 columns) and the parameter suffix. The single column
# at index nb_samples * 3 is skipped — presumably a separator/filler
# token inserted by generate_sequence_and_test_set via new_full()
# (see the function body above) — TODO confirm.
- q_train_set = input[:, : nb_samples * 3]
- q_params = input[:, nb_samples * 3 + 1 :]
- print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
+ q_train_set = seq[:, : nb_samples * 3]
+ q_params = seq[:, nb_samples * 3 + 1 :]
+ print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
# Evaluate the reconstructed parameters; nb_mlps_per_batch=17 looks
# like a deliberately non-round value to exercise uneven batching —
# NOTE(review): confirm against evaluate_q_params' definition.
error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
print(f"train {error_train*100}%")
error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)