# @XREMOTE_HOST: elk.fleuret.org
# @XREMOTE_EXEC: python
# @XREMOTE_PRE: source ${HOME}/misc/venv/pytorch/bin/activate
# @XREMOTE_PRE: killall -u ${USER} -q -9 python || true
# @XREMOTE_PRE: ln -sf ${HOME}/data/pytorch ./data
# @XREMOTE_SEND: *.py *.sh

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math, time

import torch, torchvision

from torch import nn
from torch.nn import functional as F

######################################################################
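
# Input coordinates and MLP parameters are quantized to this many
# discrete levels so that both can be serialized as tokens for a
# sequence model; the value nb_quantization_levels itself serves as a
# separator token in generate_sequence_and_test_set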
nb_quantization_levels = 101

def quantize(x, xmin, xmax):
    # Map values in [xmin, xmax] to integer levels in
    # {0, ..., nb_quantization_levels - 1}
    return (
        ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
        .long()
        .clamp(min=0, max=nb_quantization_levels - 1)
    )


def dequantize(q, xmin, xmax):
    return q / nb_quantization_levels * (xmax - xmin) + xmin
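

# A minimal round-trip sketch of the intended use, assuming values in
# [-1, 1]:
#
#   q = quantize(torch.tensor([-1.0, 0.0, 1.0]), -1, 1)  # tensor([  0,  50, 100])
#   x = dequantize(q, -1, 1)  # approximately the input, within one level (~0.02)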


######################################################################

hidden_dim = 32  # width of the MLP hidden layers (assumed value)

# Reference single-MLP architecture; the functions below work with
# explicit per-problem weight tensors instead of nn.Module
model = nn.Sequential(
    nn.Linear(2, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, 2),
)

######################################################################
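

# Generate nb_mlps binary classification problems in [-1, 1]^2, whose
# positive class is a union of random axis-aligned rectangles, train
# one MLP per problem (all in parallel), and return the quantized
# train sets, test sets, and trained parameters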
def generate_sets_and_params(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device=torch.device("cpu"),
):
    data_input = torch.zeros(nb_mlps, 2 * nb_samples, 2, device=device)
    data_targets = torch.zeros(
        nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
    )
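
    # Rejection sampling: regenerate any dataset whose two classes are
    # unbalanced by more than 10%, until all nb_mlps datasets are close
    # to 50/50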
    while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
        i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
        nb = i.sum()

        nb_rec = 2  # rectangles per problem (assumed value)

        # Sample three values per axis, sort them, and keep the two
        # extremes as (xmin, xmax, ymin, ymax) of each rectangle
        support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1
        support = support.sort(-1).values
        support = support[:, :, :, torch.tensor([0, 2])].view(nb, nb_rec, 4)

        x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1

        # A point is of class 1 if it falls inside at least one rectangle
        y = (
            (
                (x[:, None, :, 0] >= support[:, :, None, 0]).long()
                * (x[:, None, :, 0] <= support[:, :, None, 1]).long()
                * (x[:, None, :, 1] >= support[:, :, None, 2]).long()
                * (x[:, None, :, 1] <= support[:, :, None, 3]).long()
            )
            .max(dim=1)
            .values
        )

        data_input[i], data_targets[i] = x, y

    train_input, train_targets = (
        data_input[:, :nb_samples],
        data_targets[:, :nb_samples],
    )
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]
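
    # Round the inputs to the quantization grid, so the MLPs are
    # trained on exactly the values that the token sequences represent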
    q_train_input = quantize(train_input, -1, 1)
    train_input = dequantize(q_train_input, -1, 1)

    q_test_input = quantize(test_input, -1, 1)
    test_input = dequantize(q_test_input, -1, 1)

    w1 = torch.randn(nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
    b1 = torch.zeros(nb_mlps, hidden_dim, device=device)
    w2 = torch.randn(nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
    b2 = torch.zeros(nb_mlps, 2, device=device)

    for p in [w1, b1, w2, b2]:
        p.requires_grad_()

    optimizer = torch.optim.Adam([w1, b1, w2, b2], lr=1e-2)

    criterion = nn.CrossEntropyLoss()
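
    # All nb_mlps networks are trained in parallel: every tensor has a
    # leading MLP-index dimension m, and the einsum contractions apply
    # each network to its own batch of samples. After each gradient
    # step, a random fraction k / (nb_epochs - 1) of the weights is
    # snapped to its dequantized value, so the networks progressively
    # adapt to fully quantized parameters.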
    for k in range(nb_epochs):
        acc_train_loss = 0.0
        nb_train_errors = 0

        for input, targets in zip(
            train_input.split(batch_size, dim=1), train_targets.split(batch_size, dim=1)
        ):
            h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
            h = F.relu(h)
            output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
            loss = F.cross_entropy(
                output.reshape(-1, output.size(-1)), targets.reshape(-1)
            )
            acc_train_loss += loss.item() * input.size(1)

            wta = output.argmax(-1)
            nb_train_errors += (wta != targets).long().sum(-1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for p in [w1, b1, w2, b2]:
                m = (
                    torch.rand(p.size(), device=p.device) <= k / (nb_epochs - 1)
                ).long()
                pq = quantize(p, -2, 2)
                p[...] = (1 - m) * p + m * dequantize(pq, -2, 2)

        train_error = nb_train_errors / train_input.size(1)
        acc_train_loss = acc_train_loss / train_input.size(1)

        # print(f"{k=} {acc_train_loss=} {train_error=}")
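
    # Serialize everything as tokens: the parameters are quantized over
    # [-2, 2], and each sample becomes a triplet (x_0, x_1, y)
    # flattened along the sample dimension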
    q_params = torch.cat(
        [quantize(p.view(nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
    )

    q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
        nb_mlps, -1
    )

    q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
        nb_mlps, -1
    )

    return q_train_set, q_test_set, q_params


######################################################################
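

# Rebuild the per-MLP weight tensors from their quantized
# serialization and return the per-MLP error rate on the quantized set
# q_set, given as flattened (x_0, x_1, y) triplets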
def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")):
    nb_mlps = q_params.size(0)

    w1 = torch.empty(nb_mlps, hidden_dim, 2, device=device)
    b1 = torch.empty(nb_mlps, hidden_dim, device=device)
    w2 = torch.empty(nb_mlps, 2, hidden_dim, device=device)
    b2 = torch.empty(nb_mlps, 2, device=device)

    with torch.no_grad():
        k = 0
        for p in [w1, b1, w2, b2]:
            print(f"{p.size()=}")
            x = dequantize(q_params[:, k : k + p.numel() // nb_mlps], -2, 2).view(
                p.size()
            )
            p.copy_(x)
            k += p.numel() // nb_mlps

    q_set = q_set.view(nb_mlps, -1, 3)
    data_input = dequantize(q_set[:, :, :2], -1, 1).to(device)
    data_targets = q_set[:, :, 2].to(device)

    print(f"{data_input.size()=} {data_targets.size()=}")

    criterion = nn.CrossEntropyLoss()

    acc_loss = 0.0
    nb_errors = 0

    for input, targets in zip(
        data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
    ):
        h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
        h = F.relu(h)
        output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
        loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
        acc_loss += loss.item() * input.size(1)
        wta = output.argmax(-1)
        nb_errors += (wta != targets).long().sum(-1)

    error = nb_errors / data_input.size(1)
    acc_loss = acc_loss / data_input.size(1)

    return error


######################################################################
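

# Build full sequences [train-set tokens | separator | parameter
# tokens], plus an autoregressive mask that is 1 on the parameter
# tokens, i.e. the portion a sequence model would have to generate,
# and return them with the matching quantized test set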
def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
    nb_epochs,
    device,
):
    q_train_set, q_test_set, q_params = generate_sets_and_params(
        nb_mlps=nb_mlps,
        nb_samples=nb_samples,
        batch_size=batch_size,
        nb_epochs=nb_epochs,
        device=device,
    )

    input = torch.cat(
        [
            q_train_set,
            q_train_set.new_full(
                (
                    q_train_set.size(0),
                    1,
                ),
                nb_quantization_levels,
            ),
            q_params,
        ],
        dim=-1,
    )

    print(f"SANITY #1 {q_train_set.size()=} {q_params.size()=} {input.size()=}")

    # 1 on the parameter tokens that come after the separator, 0 on the
    # train set and the separator itself
    ar_mask = (
        (torch.arange(input.size(1), device=input.device) > q_train_set.size(1))
        .long()[None, :]
        .expand(nb_mlps, -1)
    )

    return input, ar_mask, q_test_set


######################################################################
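

# Self-test: generate a few batches of sequences, decode the
# parameters back out of each sequence, and check the train and test
# error of the encoded MLPs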
if __name__ == "__main__":
    nb_mlps, nb_samples = 128, 200

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    start_time = time.perf_counter()

    data = []

    # Two batches; the loop count, batch_size, and nb_epochs are
    # assumed values
    for n in range(2):
        data.append(
            generate_sequence_and_test_set(
                nb_mlps=nb_mlps,
                nb_samples=nb_samples,
                batch_size=25,
                nb_epochs=250,
                device=device,
            )
        )

    end_time = time.perf_counter()
    nb = sum([i.size(0) for i, _, _ in data])
    print(f"{nb / (end_time - start_time):.02f} samples per second")

    for input, ar_mask, q_test_set in data:
        q_train_set = input[:, : nb_samples * 3]
        q_params = input[:, nb_samples * 3 + 1 :]
        print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
        error_train = evaluate_q_params(q_params, q_train_set)
        print(f"train {error_train*100}%")
        error_test = evaluate_q_params(q_params, q_test_set)
        print(f"test {error_test*100}%")