4 "cell_type": "markdown",
10 "Any copyright is dedicated to the Public Domain.\n",
11 "https://creativecommons.org/publicdomain/zero/1.0/\n",
13 "Written by Francois Fleuret\n",
14 "https://fleuret.org/francois"
19 "execution_count": null,
27 "import torch.nn.functional as F\n",
28 "from torch import nn\n",
30 "import matplotlib.pyplot as plt\n",
32 "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
37 "execution_count": null,
" lambda x: torch.sin(x * math.pi),  # 'sin'\n",
" lambda x: torch.cos(x * math.pi),  # 'cos'\n",
" lambda x: torch.sigmoid(5 * x) * 2 - 1,  # 'sigmoid', rescaled to (-1, 1)\n",
" lambda x: 0.25 * x + 0.75 * torch.sign(x),  # 'gap', jump of 1.5 at x = 0\n",
" lambda x: torch.ceil(x * 2) / 2,  # 'stairs', piecewise-constant steps of 0.5\n",
"# Display names, one per entry of mappings, in the same order\n",
"mapping_names = [ 'id', 'sin', 'cos', 'sigmoid', 'gap', 'stairs', ]\n",
"# Compose two of the base mappings: apply mappings[n1] first, then mappings[n2].\n",
"def comp(n1, n2, x):\n",
"    inner = mappings[n1](x)\n",
"    return mappings[n2](inner)\n",
"# Plot each base mapping over the working interval [-1, 1]\n",
"x = torch.linspace(-1, 1, 250)\n",
"for f, l in zip(mappings, mapping_names):\n",
" plt.plot(x, f(x), label = l)\n",
69 "execution_count": null,
"# Build a data set of nb scalar samples for the composed mappings. probas is\n",
"# an (unnormalized) len(mappings) x len(mappings) matrix of probabilities\n",
"# over the pairs (n1, n2) of mapping indices.\n",
"def create_set(nb, probas):\n",
" # Flatten and normalize into a distribution over the len(mappings)**2 pairs\n",
" probas = probas.view(-1) / probas.sum()\n",
" # Scalar inputs drawn uniformly in [-1, 1]\n",
" x = torch.rand(nb, device = device) * 2 - 1\n",
" # Pre-compute every composition's value at every sample point\n",
" y = x.new(x.size(0), len(mappings)**2, device = device)\n",
" for k in range(len(mappings)**2):\n",
" n1 = k // len(mappings)\n",
" n2 = k % len(mappings)\n",
" y[:, k] = comp(n1, n2, x)\n",
" # Sample one pair index per example according to probas\n",
" a = torch.distributions.categorical.Categorical(probas).sample((nb,))\n",
" # y[n, 0] = y[n, a[n]], i.e. keep only the sampled pair's target value\n",
" y = y.gather(dim = 1, index = a[:, None])\n",
" # {-1,+1}-valued one-hot codes of n1 = a // m and n2 = a % m\n",
" a1 = F.one_hot(a.div(len(mappings), rounding_mode = 'floor'), num_classes = len(mappings))\n",
" a2 = F.one_hot(a%len(mappings), num_classes = len(mappings))\n",
" # Input rows are (x, code of n1, code of n2)\n",
" x = torch.cat((x[:, None], a1 * 2 - 1, a2 * 2 - 1), 1)\n",
"# Uniform over all len(mappings)**2 pairs\n",
"probas_uniform = torch.full((len(mappings), len(mappings)), 1.0, device = device)\n",
"a = torch.arange(len(mappings), device = device)\n",
"# Circulant band: pairs with (n1 - n2) mod m in [0, m/2)\n",
"probas_band = ((a[:, None] - a[None, :])%len(mappings) < len(mappings)/2).float()\n",
"# Block-diagonal: n1 and n2 fall in the same half of the index range\n",
"probas_blocks = (\n",
" a[:, None].div(len(mappings)//2, rounding_mode = 'floor') -\n",
" a[None, :].div(len(mappings)//2, rounding_mode = 'floor') == 0\n",
"# Checkerboard: pairs with n1 + n2 even\n",
"probas_checkboard = ((a[:, None] + a[None, :])%2 == 0).float()\n",
"#probas_checkboard = (((a[:, None] + a[None, :])%2 == 0) + (a[:, None] == 0) + (a[None, :] == 0)).float()\n",
"# Sanity check: show the four supports\n",
"print(probas_uniform)\n",
"print(probas_band)\n",
"print(probas_blocks)\n",
"print(probas_checkboard)"
114 "execution_count": null,
"# Train an MLP to predict comp(n1, n2, x) from (x, codes of n1 and n2).\n",
"# Training pairs are sampled from probas_train, test pairs from probas_test.\n",
"# Returns the train-set normalization statistics and the trained model.\n",
"def train_model(probas_train, probas_test, nb_samples = 100000, nb_epochs = 25):\n",
" dim_hidden = 64\n",
" model = nn.Sequential(\n",
" nn.Linear(1 + len(mappings) * 2, dim_hidden),  # input: (x, two one-hot codes)\n",
" nn.Linear(dim_hidden, dim_hidden),\n",
" nn.Linear(dim_hidden, 1),  # scalar regression output\n",
" batch_size = 100\n",
" train_input, train_targets = create_set(nb_samples, probas_train)\n",
" test_input, test_targets = create_set(nb_samples, probas_test)\n",
" # Normalize inputs with the train statistics only (no test-set leakage)\n",
" train_mu, train_std = train_input.mean(), train_input.std()\n",
" train_input = (train_input - train_mu) / train_std\n",
" test_input = (test_input - train_mu) / train_std\n",
" for k in range(nb_epochs):\n",
" # Fresh Adam each epoch with a 1/(k+1) learning-rate decay; note this\n",
" # also resets Adam's moment estimates at every epoch\n",
" optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2 /(k + 1))\n",
" acc_train_loss = 0.0  # running sum of per-sample train losses\n",
" for input, targets in zip(train_input.split(batch_size),\n",
" train_targets.split(batch_size)):\n",
" output = model(input)\n",
" loss = F.mse_loss(output, targets)\n",
" acc_train_loss += loss.item() * input.size(0)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" # Evaluate on the test split with the same per-sample accumulation\n",
" acc_test_loss = 0.0\n",
" for input, targets in zip(test_input.split(batch_size),\n",
" test_targets.split(batch_size)):\n",
" output = model(input)\n",
" loss = F.mse_loss(output, targets)\n",
" acc_test_loss += loss.item() * input.size(0)\n",
" #print(f'loss {k} {acc_train_loss/train_input.size(0):f} {acc_test_loss/test_input.size(0):f}')\n",
" return train_mu, train_std, model\n",
"# Model prediction for the composed pair (n1, n2) at the points x. Builds the\n",
"# same {-1,+1} one-hot input encoding as create_set, applies the train-set\n",
"# normalization (mu, std), and returns a detached 1d tensor of predictions.\n",
"def prediction(model, mu, std, n1, n2, x):\n",
"    def encode(n):\n",
"        # {-1,+1}-valued one-hot code of mapping index n, one row per sample\n",
"        idx = torch.full((x.size(0),), n, device = device)\n",
"        return F.one_hot(idx, num_classes = len(mappings)) * 2 - 1\n",
"    inp = torch.cat((x[:, None], encode(n1), encode(n2)), dim = 1)\n",
"    inp = (inp - mu) / std\n",
"    return model(inp).view(-1).detach()"
176 "execution_count": null,
"# Train a model with pair distribution probas_train and visualize the test\n",
"# MSE of every composed pair (n1, n2) as a matrix heat map.\n",
"def plot_result(probas_train):\n",
" train_mu, train_std, model = train_model(\n",
" probas_train = probas_train,\n",
" probas_test = probas_uniform,\n",
" # e[n1, n2] = MSE between ground truth and prediction for that pair\n",
" e = torch.empty(len(mappings), len(mappings))\n",
" x = torch.linspace(-1, 1, 250, device = device)\n",
" for n1 in range(len(mappings)):\n",
" for n2 in range(len(mappings)):\n",
" gt = comp(n1, n2, x)\n",
" pr = prediction(model, train_mu, train_std, n1, n2, x)\n",
" e[n1, n2] = F.mse_loss(gt, pr)\n",
" # Darker cell = higher error, clipped to the range [0, 1]\n",
" plt.matshow(e, cmap = plt.cm.Blues, vmin = 0, vmax = 1)\n",
"plot_result(probas_uniform)\n",
"plot_result(probas_band)\n",
"plot_result(probas_blocks)\n",
"plot_result(probas_checkboard)"
208 "execution_count": null,
"# Closer look: train on the checkerboard support and compare predictions with\n",
"# the ground truth for pairs inside and outside that support.\n",
"train_mu, train_std, model = train_model(\n",
" probas_train = probas_checkboard,\n",
" probas_test = probas_uniform,\n",
"x = torch.linspace(-1, 1, 250, device = device)\n",
"# (1, 5) and (5, 3) have n1 + n2 even, hence lie in the training support;\n",
"# (1, 2) and (4, 5) do not\n",
"for n1, n2 in [ (1, 5), (1, 2), (5, 3), (4, 5) ]:\n",
" plt.plot(x.to('cpu'), comp(n1, n2, x).to('cpu'), label = 'ground truth')\n",
" plt.plot(x.to('cpu'), prediction(model, train_mu, train_std, n1, n2, x).to('cpu'), label = 'prediction')\n",
229 "execution_count": null,
238 "display_name": "Python 3 (ipykernel)",
239 "language": "python",
247 "file_extension": ".py",
248 "mimetype": "text/x-python",
250 "nbconvert_exporter": "python",
251 "pygments_lexer": "ipython3",