4 # Any copyright is dedicated to the Public Domain.
5 # https://creativecommons.org/publicdomain/zero/1.0/
7 # Written by Francois Fleuret <francois@fleuret.org>
9 import math, argparse, os
11 import torch, torchvision
14 from torch.nn import functional as F
16 ######################################################################
# Command-line interface: a single option, --result_dir, naming the directory
# in which the warp_*.tex files are looked up and written (default: /tmp).
18 parser = argparse.ArgumentParser()
20 parser.add_argument("--result_dir", type=str, default="/tmp")
22 args = parser.parse_args()
24 ######################################################################
26 # If the source is older than the result, do nothing
# warp_0.tex (the file written for the first model stage) is used as the
# reference result whose modification time is compared with this script's.
28 ref_filename = os.path.join(args.result_dir, f"warp_0.tex")
# NOTE(review): the end of this condition and the body of the `if` are not
# visible in this listing; presumably the script stops here when the result
# is already up to date -- confirm against the full source.
30 if os.path.exists(ref_filename) and os.path.getmtime(__file__) < os.path.getmtime(
35 ######################################################################
# Synthetic two-class training set in the plane.
# Each row of x is a random (angle, radius) pair: as far as the visible part
# of the statement shows, the angle is uniform in [-pi/4, 5*pi/4) and the
# radius is uniform in [0.25, 0.35).
# NOTE(review): nb (the number of samples) is defined on a line not shown,
# and the closing parenthesis of this statement is missing from the listing.
40 x = torch.rand(nb, 2) * torch.tensor([math.pi * 1.5, 0.10]) + torch.tensor(
41 [math.pi * -0.25, 0.25]
# Class labels, drawn independently and uniformly from {0, 1}.
44 train_targets = (torch.rand(nb) < 0.5).long()
# Polar to Cartesian: (sin(angle) * radius, cos(angle) * radius).
45 train_input = torch.cat((x[:, 0:1].sin() * x[:, 1:2], x[:, 0:1].cos() * x[:, 1:2]), 1)
# train_targets * 2 - 1 maps the labels {0, 1} to the signs {-1, +1}.
# The sign first mirrors the first coordinate of class-0 points ...
46 train_input[:, 0] *= train_targets * 2 - 1
# ... and then shifts the two classes apart: +/-0.05 on the first coordinate
# and -/+0.15 on the second.
47 train_input[:, 0] += 0.05 * (train_targets * 2 - 1)
48 train_input[:, 1] -= 0.15 * (train_targets * 2 - 1)
# Network: a stack of identical stages, each a bias-free 2x2 linear map
# followed by tanh, so every intermediate representation stays 2-dimensional.
# Eight such stages are visible below.
# NOTE(review): the end of this nn.Sequential(...) call is not visible in this
# listing, so further layers may follow the stages shown; `nn` is expected to
# be imported on a line that is not shown either.
52 model = nn.Sequential(
53 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
54 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
55 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
56 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
57 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
58 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
59 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
60 nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()),
# Weight initialisation: model.modules() walks the nested containers, and
# every nn.Linear found gets its weight set to twice the identity plus a small
# Gaussian perturbation (standard deviation 1e-4).
65 for p in model.modules():
66 if isinstance(p, nn.Linear):
# NOTE(review): a line is missing between the test above and this in-place
# write; presumably it opens a torch.no_grad() context -- confirm.
68 p.weight[...] = 2 * torch.eye(2) + torch.randn(2, 2) * 1e-4
# Training setup: Adam (learning rate 1e-3) on a cross-entropy loss, for
# 1000 epochs over mini-batches of 25 samples.
70 optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
71 criterion = nn.CrossEntropyLoss()
73 nb_epochs, batch_size = 1000, 25
75 for k in range(nb_epochs):
# First pass of the epoch: iterate over the training set in mini-batches and
# accumulate the loss.
# NOTE(review): the lines that reset acc_loss, compute `output`, and run the
# backward pass / optimizer step are not visible in this listing.
# NOTE(review): the loop variable `input` shadows the Python builtin.
78 for input, targets in zip(
79 train_input.split(batch_size), train_targets.split(batch_size)
82 loss = criterion(output, targets)
# Sum of the per-batch loss values over the epoch (as a Python float).
83 acc_loss += loss.item()
# Second pass over the same mini-batches to count misclassified samples.
# NOTE(review): nb_train_errors is initialised on a line not shown.
90 for input, targets in zip(
91 train_input.split(batch_size), train_targets.split(batch_size)
93 wta = model(input).argmax(1)
94 nb_train_errors += (wta != targets).long().sum()
# Fraction of training samples whose predicted class (index of the largest
# output) differs from the label.
95 train_error = nb_train_errors / train_input.size(0)
# Log line: epoch index, accumulated loss, training error in percent.
97 print(f"loss {k} {acc_loss:.02f} {train_error*100:.02f}%")
102 ######################################################################
# Export: for each top-level stage of the model, write a file
# <result_dir>/warp_<l>.tex that holds
#   - the data points as they enter that stage, one "x y label" row per point
#     (the surrounding lines look like pgfplots scatter-plot options, with
#     class 0 drawn in blue and class 1 in red);
#   - a grid of points, deformed by the preceding stages, drawn as thin grey
#     \draw paths along both grid directions;
#   - for the last stage only, a thick black decision line.
# After each file is written, both the points and the grid are pushed through
# the current stage, so the next file shows the next representation.
# NOTE(review): several lines of this section (string delimiters, the f.write
# calls around the plot options, loop headers, the definitions of sg and phi)
# are not visible in this listing.
106 input, targets = train_input, train_targets
# sg evenly spaced coordinates in [-1.2, 1.2]; the following (partially
# visible) statement combines two expanded copies along the last dimension,
# presumably yielding the sg x sg grid of 2D points -- confirm.
108 grid = torch.linspace(-1.2, 1.2, sg)
110 (grid[:, None, None].expand(sg, sg, 1), grid[None, :, None].expand(sg, sg, 1)), -1
# l is the stage index (used in the file name), m the stage module itself.
113 for l, m in enumerate(model):
114 with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f:
117 scatter src=explicit symbolic,
118 scatter/classes={0={blue}, 1={red}},
119 scatter, mark=*, only marks, mark options={mark size=0.5},
126 f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n")
# View the (possibly already transformed) grid as an sg x sg array of points.
129 g = grid.reshape(sg, sg, -1)
# Grid lines with the first index fixed: the points g[i, 0], g[i, 1], ... are
# written in sequence for each i.
# NOTE(review): the lines choosing between this \draw prefix and a path
# continuation are not visible.
130 for i in range(g.size(0)):
131 for j in range(g.size(1)):
133 pre = "\\draw[black!25,very thin] "
136 f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
# Same again with the second index fixed, giving the other family of lines.
139 for j in range(g.size(1)):
140 for i in range(g.size(0)):
142 pre = "\\draw[black!25,very thin] "
145 f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})")
148 # add the decision line
# Only for the last stage. u = [1, -1] turns phi's two outputs into their
# difference, so (a, b) are the weights and offset of that difference as an
# affine function of phi's input.
# NOTE(review): phi is assigned on a line not shown; presumably it is the
# final layer of the model and has a bias -- confirm.
150 if l == len(model) - 1:
151 u = torch.tensor([[1.0, -1.0]])
153 a, b = (u @ phi.weight).squeeze(), (u @ phi.bias).item()
# p is the multiple of a with a . p == b. The drawn segment goes through p
# and is perpendicular to a, extending by |a| on each side.
# NOTE(review): the zero set of a . x + b passes through -p, not +p; check
# whether the sign here is intentional.
154 p = a * (b / (a @ a.t()).item())
156 f"\\draw[black,thick] ({p[0]-a[1]},{p[1]+a[0]}) -- ({p[0]+a[1]},{p[1]-a[0]});"
# Advance both the data points and the grid through the current stage.
159 input, grid = m(input), m(grid)
161 ######################################################################