X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=qmlp.py;fp=qmlp.py;h=a7defe49d6244301cc091e9cc3e0b2b3236d8621;hb=26ef53ee3769c3b6b92b85d15b5a43cbd18ede07;hp=e12f0e1a6e16e16d9abf7a5ee18edde1e53d7763;hpb=f44ab6863f93ae348e66ffbf52251d96d3b5453c;p=picoclvr.git diff --git a/qmlp.py b/qmlp.py index e12f0e1..a7defe4 100755 --- a/qmlp.py +++ b/qmlp.py @@ -224,7 +224,7 @@ def generate_sequence_and_test_set( nb_mlps_per_batch=1024, ): - inputs, q_test_sets = [],[] + seqs, q_test_sets = [],[] for n in range(0,nb_mlps,nb_mlps_per_batch): q_train_set, q_test_set, q_params = generate_sets_and_params( @@ -235,7 +235,7 @@ def generate_sequence_and_test_set( device=device, ) - inputs.append(torch.cat( + seqs.append(torch.cat( [ q_train_set, q_train_set.new_full( @@ -252,10 +252,10 @@ def generate_sequence_and_test_set( q_test_sets.append(q_test_set) - input = torch.cat(inputs) + seq = torch.cat(seqs) q_test_set = torch.cat(q_test_sets) - return input, q_test_set + return seq, q_test_set ###################################################################### @@ -271,7 +271,7 @@ if __name__ == "__main__": data = [] - input, q_test_set = generate_sequence_and_test_set( + seq, q_test_set = generate_sequence_and_test_set( nb_mlps=batch_nb_mlps, nb_samples=nb_samples, device=device, @@ -281,11 +281,11 @@ if __name__ == "__main__": ) end_time = time.perf_counter() - print(f"{input.size(0) / (end_time - start_time):.02f} samples per second") + print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second") - q_train_set = input[:, : nb_samples * 3] - q_params = input[:, nb_samples * 3 + 1 :] - print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}") + q_train_set = seq[:, : nb_samples * 3] + q_params = seq[:, nb_samples * 3 + 1 :] + print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}") error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17) print(f"train {error_train*100}%") error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)