X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=qmlp.py;h=abebfc17e1befec6e1e612b1197f47909df083cc;hb=fdc61b7e50e029aac58b10f377acdce549532f84;hp=e12f0e1a6e16e16d9abf7a5ee18edde1e53d7763;hpb=f44ab6863f93ae348e66ffbf52251d96d3b5453c;p=picoclvr.git

diff --git a/qmlp.py b/qmlp.py
index e12f0e1..abebfc1 100755
--- a/qmlp.py
+++ b/qmlp.py
@@ -39,8 +39,6 @@ def dequantize(q, xmin, xmax):
 ######################################################################
 
 
-
-
 def generate_sets_and_params(
     batch_nb_mlps,
     nb_samples,
@@ -48,20 +46,24 @@ def generate_sets_and_params(
     nb_epochs,
     device=torch.device("cpu"),
     print_log=False,
+    save_as_examples=False,
 ):
     data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
     data_targets = torch.zeros(
         batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
     )
 
+    nb_rec = 8
+    nb_values = 2  # more increases the min-max gap
+
+    rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)
+
     while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
         i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
         nb = i.sum()
-
-        nb_rec = 2
-        support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1
+        support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
         support = support.sort(-1).values
-        support = support[:, :, :, torch.tensor([0, 2])].view(nb, nb_rec, 4)
+        support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
 
         x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
         y = (
@@ -75,7 +77,7 @@ def generate_sets_and_params(
             .values
         )
 
-        data_input[i], data_targets[i] = x, y
+        data_input[i], data_targets[i], rec_support[i] = x, y, support
 
     train_input, train_targets = (
         data_input[:, :nb_samples],
@@ -85,16 +87,62 @@ def generate_sets_and_params(
 
     q_train_input = quantize(train_input, -1, 1)
     train_input = dequantize(q_train_input, -1, 1)
-    train_targets = train_targets
 
     q_test_input = quantize(test_input, -1, 1)
     test_input = dequantize(q_test_input, -1, 1)
-    test_targets = test_targets
+
+    if save_as_examples:
+        a = (
+            2
+            * torch.arange(nb_quantization_levels).float()
+            / (nb_quantization_levels - 1)
+            - 1
+        )
+        xf = torch.cat(
+            [
+                a[:, None, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+                a[None, :, None].expand(
+                    nb_quantization_levels, nb_quantization_levels, 1
+                ),
+            ],
+            2,
+        )
+        xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
+        print(f"{xf.size()=} {x.size()=}")
+        yf = (
+            (
+                (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
+                * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
+                * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
+                * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
+            )
+            .max(dim=1)
+            .values
+        )
+
+        full_input, full_targets = xf, yf
+
+        q_full_input = quantize(full_input, -1, 1)
+        full_input = dequantize(q_full_input, -1, 1)
+
+        for k in range(q_full_input[:10].size(0)):
+            with open(f"example_full_{k:04d}.dat", "w") as f:
+                for u, c in zip(full_input[k], full_targets[k]):
+                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
+
+        for k in range(q_train_input[:10].size(0)):
+            with open(f"example_train_{k:04d}.dat", "w") as f:
+                for u, c in zip(train_input[k], train_targets[k]):
+                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
 
     hidden_dim = 32
     w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
     b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
-    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
+    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
+        hidden_dim
+    )
     b2 = torch.zeros(batch_nb_mlps, 2, device=device)
 
     w1.requires_grad_()
@@ -141,6 +189,22 @@ def generate_sets_and_params(
 
         # print(f"{k=} {acc_train_loss=} {train_error=}")
 
+    acc_test_loss = 0
+    nb_test_errors = 0
+
+    for input, targets in zip(
+        test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
+    ):
+        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+        h = F.relu(h)
+        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
+        acc_test_loss += loss.item() * input.size(0)
+
+        wta = output.argmax(-1)
+        nb_test_errors += (wta != targets).long().sum(-1)
+
+    test_error = nb_test_errors / test_input.size(1)
     q_params = torch.cat(
         [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
     )
@@ -151,21 +215,27 @@ def generate_sets_and_params(
         batch_nb_mlps, -1
     )
 
-    return q_train_set, q_test_set, q_params
+    return q_train_set, q_test_set, q_params, test_error
 
 
 ######################################################################
 
 
-def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024):
-
+def evaluate_q_params(
+    q_params,
+    q_set,
+    batch_size=25,
+    device=torch.device("cpu"),
+    nb_mlps_per_batch=1024,
+    save_as_examples=False,
+):
     errors = []
     nb_mlps = q_params.size(0)
 
-    for n in range(0,nb_mlps,nb_mlps_per_batch):
-        batch_nb_mlps = min(nb_mlps_per_batch,nb_mlps-n)
-        batch_q_params = q_params[n:n+batch_nb_mlps]
-        batch_q_set = q_set[n:n+batch_nb_mlps]
+    for n in range(0, nb_mlps, nb_mlps_per_batch):
+        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
+        batch_q_params = q_params[n : n + batch_nb_mlps]
+        batch_q_set = q_set[n : n + batch_nb_mlps]
         hidden_dim = 32
         w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
         b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
@@ -176,9 +246,9 @@ def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")
         k = 0
         for p in [w1, b1, w2, b2]:
             print(f"{p.size()=}")
-            x = dequantize(batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2).view(
-                p.size()
-            )
+            x = dequantize(
+                batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
+            ).view(p.size())
             p.copy_(x)
             k += p.numel() // batch_nb_mlps
 
@@ -200,7 +270,9 @@ def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")
             h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
             h = F.relu(h)
             output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
-            loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
+            loss = F.cross_entropy(
+                output.reshape(-1, output.size(-1)), targets.reshape(-1)
+            )
             acc_loss += loss.item() * input.size(0)
             wta = output.argmax(-1)
             nb_errors += (wta != targets).long().sum(-1)
@@ -208,7 +280,6 @@ def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")
         errors.append(nb_errors / data_input.size(1))
         acc_loss = acc_loss / data_input.size(1)
 
-
     return torch.cat(errors)
 
 
@@ -223,39 +294,42 @@ def generate_sequence_and_test_set(
     device,
     nb_mlps_per_batch=1024,
 ):
+    seqs, q_test_sets, test_errors = [], [], []
 
-    inputs, q_test_sets = [],[]
-
-    for n in range(0,nb_mlps,nb_mlps_per_batch):
-        q_train_set, q_test_set, q_params = generate_sets_and_params(
-            batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n),
+    for n in range(0, nb_mlps, nb_mlps_per_batch):
+        q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
+            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
             nb_samples=nb_samples,
             batch_size=batch_size,
             nb_epochs=nb_epochs,
             device=device,
         )
 
-        inputs.append(torch.cat(
-            [
-                q_train_set,
-                q_train_set.new_full(
-                    (
-                        q_train_set.size(0),
-                        1,
+        seqs.append(
+            torch.cat(
+                [
+                    q_train_set,
+                    q_train_set.new_full(
+                        (
+                            q_train_set.size(0),
+                            1,
+                        ),
+                        nb_quantization_levels,
                     ),
-                    nb_quantization_levels,
-                ),
-                q_params,
-            ],
-            dim=-1,
-        ))
+                    q_params,
+                ],
+                dim=-1,
+            )
+        )
 
         q_test_sets.append(q_test_set)
+        test_errors.append(test_error)
 
-    input = torch.cat(inputs)
+    seq = torch.cat(seqs)
     q_test_set = torch.cat(q_test_sets)
+    test_error = torch.cat(test_errors)
 
-    return input, q_test_set
+    return seq, q_test_set, test_error
 
 
 ######################################################################
@@ -263,7 +337,19 @@ def generate_sequence_and_test_set(
 if __name__ == "__main__":
     import time
 
-    batch_nb_mlps, nb_samples = 128, 500
+    batch_nb_mlps, nb_samples = 128, 250
+
+    generate_sets_and_params(
+        batch_nb_mlps=10,
+        nb_samples=nb_samples,
+        batch_size=25,
+        nb_epochs=100,
+        device=torch.device("cpu"),
+        print_log=False,
+        save_as_examples=True,
+    )
+
+    exit(0)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -271,21 +357,21 @@ if __name__ == "__main__":
     data = []
 
-    input, q_test_set = generate_sequence_and_test_set(
+    seq, q_test_set, test_error = generate_sequence_and_test_set(
         nb_mlps=batch_nb_mlps,
         nb_samples=nb_samples,
         device=device,
         batch_size=25,
         nb_epochs=250,
-        nb_mlps_per_batch=17
+        nb_mlps_per_batch=17,
     )
 
     end_time = time.perf_counter()
-    print(f"{input.size(0) / (end_time - start_time):.02f} samples per second")
+    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")
 
-    q_train_set = input[:, : nb_samples * 3]
-    q_params = input[:, nb_samples * 3 + 1 :]
-    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
+    q_train_set = seq[:, : nb_samples * 3]
+    q_params = seq[:, nb_samples * 3 + 1 :]
+    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
     error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
     print(f"train {error_train*100}%")
     error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
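Note (illustration, not part of the patch): quantize() and dequantize() are called throughout the hunks above, but their bodies lie outside this diff. The sketch below shows a plausible uniform quantizer over [xmin, xmax] with nb_quantization_levels bins, consistent with how the patch uses the two helpers (inputs in [-1, 1], integer tokens, round-trip before training); the level count and the data here are assumptions made for the example, not the repository's actual implementation.

import torch

nb_quantization_levels = 101  # assumed value; the real constant is defined elsewhere in qmlp.py


def quantize(x, xmin, xmax):
    # map real values in [xmin, xmax] to integer levels 0 .. nb_quantization_levels - 1
    p = (x - xmin) / (xmax - xmin)
    q = (p * (nb_quantization_levels - 1)).round().long()
    return q.clamp(0, nb_quantization_levels - 1)


def dequantize(q, xmin, xmax):
    # map integer levels back to real values in [xmin, xmax]
    return q.float() / (nb_quantization_levels - 1) * (xmax - xmin) + xmin


x = torch.rand(4, 10, 2) * 2 - 1  # random points in [-1, 1]^2, as in the patch
x_hat = dequantize(quantize(x, -1, 1), -1, 1)
print((x - x_hat).abs().max().item())  # round-trip error is at most half a bin width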
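Note (illustration, not part of the patch): the test-error block added to generate_sets_and_params() and the evaluation loop in evaluate_q_params() both run a whole batch of independent two-layer MLPs in parallel with einsum contractions. The self-contained sketch below reproduces that forward pass and the per-MLP error count on random data; the sizes M, N, H and the dummy labels are invented for the example.

import torch
import torch.nn.functional as F

M, N, H = 4, 8, 32  # M parallel MLPs, N points per MLP, hidden width H (hypothetical sizes)

# one weight/bias set per MLP, shaped as in the patch: w1 (M, H, 2), b1 (M, H), w2 (M, 2, H), b2 (M, 2)
w1 = torch.randn(M, H, 2) / 2**0.5
b1 = torch.zeros(M, H)
w2 = torch.randn(M, 2, H) / H**0.5
b2 = torch.zeros(M, 2)

x = torch.rand(M, N, 2) * 2 - 1  # input points in [-1, 1]^2
targets = torch.randint(0, 2, (M, N))  # dummy binary labels

# "mij,mnj->mni": each MLP m applies its own weight matrix to its own batch of points
h = torch.einsum("mij,mnj->mni", w1, x) + b1[:, None, :]  # (M, N, H)
h = F.relu(h)
logits = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]  # (M, N, 2)

loss = F.cross_entropy(logits.reshape(-1, 2), targets.reshape(-1))
errors = (logits.argmax(-1) != targets).long().sum(-1)  # per-MLP error count, as in the patch
print(loss.item(), errors / N)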