# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_EXEC: python
# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
# @XREMOTE_SEND: *.py *.sh

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>
import math, time

import torch, torchvision

from torch import nn
from torch.nn import functional as F
######################################################################

nb_quantization_levels = 101
def quantize(x, xmin, xmax):
    return (
        ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
        .long()
        .clamp(min=0, max=nb_quantization_levels - 1)
    )


def dequantize(q, xmin, xmax):
    return q / nb_quantization_levels * (xmax - xmin) + xmin
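
# A quick round-trip sanity check (editor's sketch, not from the original
# file): quantize() floors to the lower edge of each bin, so dequantization
# recovers x within one bin width of (xmax - xmin) / nb_quantization_levels:
#
#   x = torch.rand(1000) * 2 - 1
#   x_hat = dequantize(quantize(x, -1, 1), -1, 1)
#   assert (x_hat - x).abs().max() <= 2 / nb_quantization_levels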

######################################################################

def generate_sets_and_params(
    batch_nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device=torch.device("cpu"),
    print_log=False,
    save_as_examples=False,
):
    data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
    data_targets = torch.zeros(
        batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
    )

    nb_rec = 8
    nb_values = 2  # more increases the min-max gap

    rec_support = torch.empty(batch_nb_mlps, nb_rec, 4, device=device)

    # Sample each problem as a union of nb_rec axis-aligned rectangles, and
    # re-draw any problem whose class balance deviates from 50/50 by more
    # than 10%
    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
        nb = i.sum()

        support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
        support = support.sort(-1).values
        support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(
            nb, nb_rec, 4
        )

        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
        y = (
            (
                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        data_input[i], data_targets[i], rec_support[i] = x, y, support
    train_input, train_targets = (
        data_input[:, :nb_samples],
        data_targets[:, :nb_samples],
    )
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]

    q_train_input = quantize(train_input, -1, 1)
    train_input = dequantize(q_train_input, -1, 1)

    q_test_input = quantize(test_input, -1, 1)
    test_input = dequantize(q_test_input, -1, 1)
    if save_as_examples:
        # Regular grid over [-1, 1]^2 at the quantization resolution, to dump
        # the full class maps of the first MLPs
        a = (
            2 * torch.arange(nb_quantization_levels).float() / (nb_quantization_levels - 1)
            - 1
        )
        xf = torch.cat(
            [
                a[:, None, None].expand(
                    nb_quantization_levels, nb_quantization_levels, 1
                ),
                a[None, :, None].expand(
                    nb_quantization_levels, nb_quantization_levels, 1
                ),
            ],
            2,
        )
        xf = xf.reshape(1, -1, 2).expand(min(q_train_input.size(0), 10), -1, -1)
        print(f"{xf.size()=} {x.size()=}")
        yf = (
            (
                (xf[:, None, :, 0] >= rec_support[: xf.size(0), :, None, 0]).long()
                * (xf[:, None, :, 0] <= rec_support[: xf.size(0), :, None, 1]).long()
                * (xf[:, None, :, 1] >= rec_support[: xf.size(0), :, None, 2]).long()
                * (xf[:, None, :, 1] <= rec_support[: xf.size(0), :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        full_input, full_targets = xf, yf

        q_full_input = quantize(full_input, -1, 1)
        full_input = dequantize(q_full_input, -1, 1)

        # One "class x y" line per point
        for k in range(q_full_input[:10].size(0)):
            with open(f"example_full_{k:04d}.dat", "w") as f:
                for u, c in zip(full_input[k], full_targets[k]):
                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")

        for k in range(q_train_input[:10].size(0)):
            with open(f"example_train_{k:04d}.dat", "w") as f:
                for u, c in zip(train_input[k], train_targets[k]):
                    f.write(f"{c} {u[0].item()} {u[1].item()}\n")
    hidden_dim = 32

    w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
    b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
    w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
        hidden_dim
    )
    b2 = torch.zeros(batch_nb_mlps, 2, device=device)

    w1.requires_grad_()
    b1.requires_grad_()
    w2.requires_grad_()
    b2.requires_grad_()

    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)

    criterion = nn.CrossEntropyLoss()
    criterion.to(device)
    for k in range(nb_epochs):
        acc_train_loss = 0.0
        nb_train_errors = 0

        for input, targets in zip(
            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_train_loss += loss.item() * input.size(0)

            wta = output.argmax(-1)
            nb_train_errors += (wta != targets).long().sum(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Progressive quantization: after each epoch, substitute a growing
        # random fraction of the weights with their quantized values
        with torch.no_grad():
            for p in [w1, b1, w2, b2]:
                m = (
                    torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
                ).long()
                pq = quantize(p, -2, 2)
                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)

        train_error = nb_train_errors / train_input.size(1)
        acc_train_loss = acc_train_loss / train_input.size(1)

        if print_log:
            print(f"{k=} {acc_train_loss=} {train_error=}")
    acc_test_loss = 0.0
    nb_test_errors = 0

    for input, targets in zip(
        test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
    ):
        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
        h = F.relu(h)
        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
        acc_test_loss += loss.item() * input.size(0)

        wta = output.argmax(-1)
        nb_test_errors += (wta != targets).long().sum(-1)

    test_error = nb_test_errors / test_input.size(1)

    # Quantize the parameters in [-2, 2] and flatten the train / test sets
    # into one (x, y, class) triple per sample
    q_params = torch.cat(
        [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
    )
    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )
    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
        batch_nb_mlps, -1
    )

    return q_train_set, q_test_set, q_params, test_error
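
# Shapes, for reference (editor's sketch, assuming the hidden_dim = 32
# reconstructed above): q_train_set and q_test_set are int64 tensors of size
# (batch_nb_mlps, nb_samples * 3), q_params has 32*2 + 32 + 2*32 + 2 = 162
# tokens per MLP, and test_error holds one error rate per MLP:
#
#   q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
#       batch_nb_mlps=4, nb_samples=100, batch_size=25, nb_epochs=10
#   )
#   assert q_train_set.size() == (4, 300) and q_params.size() == (4, 162)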

######################################################################

def evaluate_q_params(
    q_params,
    q_set,
    batch_size=25,
    device=torch.device("cpu"),
    nb_mlps_per_batch=1024,
    save_as_examples=False,
):
    errors = []
    nb_mlps = q_params.size(0)

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
        batch_q_params = q_params[n : n + batch_nb_mlps]
        batch_q_set = q_set[n : n + batch_nb_mlps]

        hidden_dim = 32
        w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
        b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
        w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
        b2 = torch.empty(batch_nb_mlps, 2, device=device)

        # Dequantize the flat parameter sequence back into the four weight
        # tensors, in the same order they were packed
        with torch.no_grad():
            k = 0
            for p in [w1, b1, w2, b2]:
                print(f"{p.size()=}")
                p.copy_(
                    dequantize(
                        batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
                    ).view(p.size())
                )
                k += p.numel() // batch_nb_mlps

        batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
        data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
        data_targets = batch_q_set[:, :, 2].to(device)

        print(f"{data_input.size()=} {data_targets.size()=}")
        criterion = nn.CrossEntropyLoss()
        criterion.to(device)

        acc_loss = 0.0
        nb_errors = 0

        for input, targets in zip(
            data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_loss += loss.item() * input.size(0)
            wta = output.argmax(-1)
            nb_errors += (wta != targets).long().sum(-1)

        errors.append(nb_errors / data_input.size(1))
        acc_loss = acc_loss / data_input.size(1)

    return torch.cat(errors)
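
# Round-trip usage (editor's sketch): parameters produced by
# generate_sets_and_params() can be evaluated directly on the quantized sets
# it returns, which is exactly what __main__ below does as a sanity check:
#
#   q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
#       batch_nb_mlps=4, nb_samples=100, batch_size=25, nb_epochs=10
#   )
#   print(evaluate_q_params(q_params, q_train_set))  # per-MLP train error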

######################################################################

def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device,
    nb_mlps_per_batch=1024,
):
    seqs, q_test_sets, test_errors = [], [], []

    for n in range(0, nb_mlps, nb_mlps_per_batch):
        q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
            nb_samples=nb_samples,
            batch_size=batch_size,
            nb_epochs=nb_epochs,
            device=device,
        )

        # One sequence per MLP: the quantized train set, a separator token
        # (nb_quantization_levels, one above the largest data token), then
        # the quantized parameters
        seqs.append(
            torch.cat(
                [
                    q_train_set,
                    q_train_set.new_full(
                        (q_train_set.size(0), 1),
                        nb_quantization_levels,
                    ),
                    q_params,
                ],
                dim=-1,
            )
        )

        q_test_sets.append(q_test_set)
        test_errors.append(test_error)

    seq = torch.cat(seqs)
    q_test_set = torch.cat(q_test_sets)
    test_error = torch.cat(test_errors)

    return seq, q_test_set, test_error
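
# Each row of seq is laid out as (editor's sketch, with 162 parameter tokens
# under the hidden_dim = 32 reconstruction): nb_samples (x, y, class) triples,
# one separator token, then the quantized parameters; __main__ below relies on
# this when it slices seq[:, : nb_samples * 3] and seq[:, nb_samples * 3 + 1 :]:
#
#   seq, q_test_set, test_error = generate_sequence_and_test_set(
#       nb_mlps=8, nb_samples=100, batch_size=25, nb_epochs=10,
#       device=torch.device("cpu"), nb_mlps_per_batch=4,
#   )
#   assert seq.size(1) == 100 * 3 + 1 + 162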

######################################################################

if __name__ == "__main__":
    batch_nb_mlps, nb_samples = 128, 250

    # Small CPU run whose only purpose is to write the example_*.dat files
    generate_sets_and_params(
        batch_nb_mlps=10,
        nb_samples=nb_samples,
        batch_size=25,
        nb_epochs=100,
        device=torch.device("cpu"),
        print_log=False,
        save_as_examples=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.perf_counter()

    seq, q_test_set, test_error = generate_sequence_and_test_set(
        nb_mlps=batch_nb_mlps,
        nb_samples=nb_samples,
        batch_size=25,
        nb_epochs=250,
        device=device,
        nb_mlps_per_batch=17,
    )

    end_time = time.perf_counter()
    print(f"{seq.size(0) / (end_time - start_time):.02f} samples per second")

    q_train_set = seq[:, : nb_samples * 3]
    q_params = seq[:, nb_samples * 3 + 1 :]
    print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {seq.size()=}")
    error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
    print(f"train {error_train*100}%")
    error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
    print(f"test {error_test*100}%")