From 8f1ae8ff3931ade16bcd7d01f6107a3a474e5baf Mon Sep 17 00:00:00 2001 From: =?utf8?q?Fran=C3=A7ois=20Fleuret?= Date: Fri, 16 Jun 2023 16:45:41 +0200 Subject: [PATCH] Initial commit --- warp.py | 168 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ warp.tex | 66 ++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100755 warp.py create mode 100644 warp.tex diff --git a/warp.py b/warp.py new file mode 100755 index 0000000..96dfa11 --- /dev/null +++ b/warp.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python + + +# Any copyright is dedicated to the Public Domain. +# https://creativecommons.org/publicdomain/zero/1.0/ + +# Written by Francois Fleuret + +import math, argparse, os + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +###################################################################### + +parser = argparse.ArgumentParser() + +parser.add_argument("--result_dir", type=str, default="/tmp") + +args = parser.parse_args() + +###################################################################### + +# If the source is older than the result, do nothing + +ref_filename = os.path.join(args.result_dir, f"warp_0.tex") + +if os.path.exists(ref_filename) and os.path.getmtime(__file__) < os.path.getmtime( + ref_filename +): + exit(0) + +###################################################################### + +torch.manual_seed(0) + +nb = 1000 +x = torch.rand(nb, 2) * torch.tensor([math.pi * 1.5, 0.10]) + torch.tensor( + [math.pi * -0.25, 0.25] +) + +train_targets = (torch.rand(nb) < 0.5).long() +train_input = torch.cat((x[:, 0:1].sin() * x[:, 1:2], x[:, 0:1].cos() * x[:, 1:2]), 1) +train_input[:, 0] *= train_targets * 2 - 1 +train_input[:, 0] += 0.05 * (train_targets * 2 - 1) +train_input[:, 1] -= 0.15 * (train_targets * 2 - 1) +train_input *= 1.2 + + +class WithResidual(nn.Module): + def __init__(self, *f): + super().__init__() + self.f = f[0] if len(f) == 1 else nn.Sequential(*f) + + def forward(self, x): + return 0.5 * x + 0.5 * self.f(x) + + +model = nn.Sequential( + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Sequential(nn.Linear(2, 2, bias=False), nn.Tanh()), + nn.Linear(2, 2), +) + +with torch.no_grad(): + for p in model.modules(): + if isinstance(p, nn.Linear): + # p.bias.zero_() + p.weight[...] = 2 * torch.eye(2) + torch.randn(2, 2) * 1e-4 + +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) +criterion = nn.CrossEntropyLoss() + +nb_epochs, batch_size = 1000, 25 + +for k in range(nb_epochs): + acc_loss = 0.0 + + for input, targets in zip( + train_input.split(batch_size), train_targets.split(batch_size) + ): + output = model(input) + loss = criterion(output, targets) + acc_loss += loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + nb_train_errors = 0 + for input, targets in zip( + train_input.split(batch_size), train_targets.split(batch_size) + ): + wta = model(input).argmax(1) + nb_train_errors += (wta != targets).long().sum() + train_error = nb_train_errors / train_input.size(0) + + print(f"loss {k} {acc_loss:.02f} {train_error*100:.02f}%") + + if train_error == 0: + break + +###################################################################### + +sg=25 + +input, targets = train_input, train_targets + +grid = torch.linspace(-1.2,1.2,sg) +grid = torch.cat((grid[:,None,None].expand(sg,sg,1),grid[None,:,None].expand(sg,sg,1)),-1).reshape(-1,2) + +for l, m in enumerate(model): + with open(os.path.join(args.result_dir, f"warp_{l}.tex"), "w") as f: + f.write( + """\\addplot[ + scatter src=explicit symbolic, + scatter/classes={0={blue}, 1={red}}, + scatter, mark=*, only marks, mark options={mark size=0.5}, +]% +table[meta=label] { +x y label +""" + ) + for k in range(512): + f.write(f"{input[k,0]} {input[k,1]} {targets[k]}\n") + f.write("};\n") + + g = grid.reshape(sg,sg,-1) + for i in range(g.size(0)): + for j in range(g.size(1)): + if j == 0: + pre="\\draw[black!25,very thin] " + else: + pre="--" + f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})") + f.write(";\n") + + for j in range(g.size(1)): + for i in range(g.size(0)): + if i == 0: + pre="\\draw[black!25,very thin] " + else: + pre="--" + f.write(f"{pre} ({g[i,j,0]},{g[i,j,1]})") + f.write(";\n") + + # add the decision line + + if l == len(model) - 1: + u = torch.tensor([[1.0, -1.0]]) + phi = model[-1] + a, b = (u @ phi.weight).squeeze(), (u @ phi.bias).item() + p = a * (b / (a @ a.t()).item()) + f.write( + f"\\draw[black,thick] ({p[0]-a[1]},{p[1]+a[0]}) -- ({p[0]+a[1]},{p[1]-a[0]});" + ) + + input, grid = m(input), m(grid) + +###################################################################### diff --git a/warp.tex b/warp.tex new file mode 100644 index 0000000..d03b295 --- /dev/null +++ b/warp.tex @@ -0,0 +1,66 @@ +%% -*- mode: latex; mode: reftex; mode: flyspell; coding: utf-8; tex-command: "pdflatex.sh" -*- + +\documentclass[11pt,a4paper,twoside]{article} +\usepackage[a4paper,top=2.5cm,bottom=2cm,left=2.5cm,right=2.5cm]{geometry} +\usepackage[colorlinks=true,linkcolor=blue,urlcolor=blue,citecolor=blue]{hyperref} +\usepackage{amsmath} +\usepackage{amssymb} +\usepackage{dsfont} +\usepackage{tikz} +\usetikzlibrary{arrows,arrows.meta,calc} +\usetikzlibrary{patterns,backgrounds} +\usetikzlibrary{positioning,fit} +\usetikzlibrary{shapes.geometric,shapes.multipart} +\usetikzlibrary{patterns.meta,decorations.pathreplacing,calligraphy} +\usetikzlibrary{tikzmark} +\usetikzlibrary{decorations.pathmorphing} +\usepackage{pgfplots} +\usepgfplotslibrary{patchplots,colormaps} +\pgfplotsset{compat = newest} + + +\begin{document} + +\definecolor{blue}{rgb}{0.3,0.5,0.85} +\definecolor{red}{rgb}{0.65,0.0,0.0} + +\begin{figure} + + \immediate\write18{./warp.py --result_dir=.} + + \newcommand{\warp}[1]{% + \begin{tikzpicture} + \begin{axis}[ticks=none,width=7.0cm, height=7.0cm,xmin=-1.2,xmax=1.2,ymin=-1.2,ymax=1.2] + \input{#1} + \end{axis} + \end{tikzpicture} + } + + \center + + \begin{tikzpicture}[warp/.style={inner sep=1pt,minimum width=5.0cm,minimum height=5.0cm}] + \node[warp] (W0) {\warp{warp_0.tex}}; + \node[warp,right=2pt of W0] (W1) {\warp{warp_1.tex}}; + \node[warp,right=2pt of W1] (W2) {\warp{warp_2.tex}}; + \node[warp,below=20pt of W0] (W3) {\warp{warp_3.tex}}; + \node[warp,right=2pt of W3] (W4) {\warp{warp_4.tex}}; + \node[warp,right=2pt of W4] (W5) {\warp{warp_5.tex}}; + \node[warp,below=20pt of W3] (W6) {\warp{warp_6.tex}}; + \node[warp,right=2pt of W6] (W7) {\warp{warp_7.tex}}; + \node[warp,right=2pt of W7] (W8) {\warp{warp_8.tex}}; + \node[inner sep=0pt,below=4pt of W0] (lW0) {\footnotesize Input}; + \foreach \n in {1,...,8}{ + \node[inner sep=0pt,below=4pt of W\n] (lW\n) {\footnotesize Layer \#\n}; + }; + + \end{tikzpicture} + + \caption[Feature warping]{Each plot shows the deformation of the space + and the resulting distribution of the training points in + $\mathbb{R}^2$ corresponding to the output of each layer, starting + with the input in the top-left square. The thick oblique line in the + bottom-right plot shows the final affine decision.}\label{fig:warp} + +\end{figure} + +\end{document} -- 2.39.5