#!/usr/bin/env python

# Any copyright is dedicated to the Public Domain.
# https://creativecommons.org/publicdomain/zero/1.0/

# Written by Francois Fleuret <francois@fleuret.org>

import math
import matplotlib.pyplot as plt

import torch
from torch import nn

######################################################################


def data_rectangle(nb):
    """Sample `nb` points uniformly from a tilted 2 x 1 rectangle.

    One coordinate is drawn uniformly in [-1, 1) and the other in
    [-0.5, 0.5); the resulting axis-aligned rectangle is then rotated
    by pi / 8 around the origin.

    Returns a pair `(points, name)` where `points` is an `nb x 2` tensor
    and `name` is the string "rectangle".
    """
    # Keep the two draws in this order so that the random stream is
    # consumed short side first, long side second.
    short_side = torch.rand(nb, 1) - 0.5
    long_side = torch.rand(nb, 1) * 2 - 1
    points = torch.cat([long_side, short_side], dim=1)

    angle = math.pi / 8
    c, s = math.cos(angle), math.sin(angle)
    # Points are row vectors, hence the multiplication on the right
    rotation = torch.tensor([[c, s], [-s, c]])

    return torch.matmul(points, rotation), "rectangle"


def data_zigzag(nb):
    """Sample `nb` points along a cosine-shaped zigzag curve.

    A parameter t is drawn uniformly in [0, 1); the point has one
    coordinate growing linearly with t from -1.25 to 1.25 and the other
    oscillating as 0.4 * cos((t - 0.5) * 5 * pi). The two coordinates
    are then mixed with the matrix [[1, -1], [1, 1]], which is a
    45-degree rotation combined with a sqrt(2) scaling.

    Returns a pair `(points, name)` where `points` is an `nb x 2` tensor
    and `name` is the string "zigzag".
    """
    t = torch.empty(nb).uniform_(0, 1).view(-1, 1)

    phase = (t - 0.5) * 5 * math.pi
    wiggle = 0.4 * torch.cos(phase)
    along = t * 2.5 - 1.25

    mixing = torch.tensor([[1.0, -1.0], [1.0, 1.0]])
    points = torch.cat([along, wiggle], dim=1)

    return torch.matmul(points, mixing), "zigzag"


def data_spiral(nb):
    """Sample `nb` points along a spiral arc.

    A parameter t is drawn uniformly in [0, 1); the angle is
    t * 2.25 * pi and the radius is t * 0.8 + 0.5, so the radius grows
    from 0.5 to 1.3 as the angle increases.

    Returns a pair `(points, name)` where `points` is an `nb x 2` tensor
    whose first column is the sine component and second the cosine
    component, and `name` is the string "spiral".
    """
    t = torch.empty(nb).uniform_(0, 1).view(-1, 1)

    angle = t * 2.25 * math.pi
    radius = t * 0.8 + 0.5

    points = torch.cat(
        [torch.sin(angle) * radius, torch.cos(angle) * radius], dim=1
    )

    return points, "spiral"


def data_penta(nb):
    """Sample `nb` noisy points around five spots on the unit circle.

    Each point picks one of five equally spaced angles (multiples of
    2 * pi / 5) uniformly at random, is placed on the unit circle at
    that angle, and is perturbed with Gaussian noise of standard
    deviation 0.05.

    Returns a pair `(points, name)` where `points` is an `nb x 2` tensor
    and `name` is the string "penta".
    """
    vertex = torch.randint(5, (nb,)).float()
    angle = (vertex / 5 * 2 * math.pi).view(-1, 1)

    centers = torch.cat([torch.sin(angle), torch.cos(angle)], dim=1)
    jitter = torch.empty_like(centers).normal_(0, 0.05)

    return centers + jitter, "penta"


######################################################################


def train_model(data, *, nb_epochs=1000, batch_size=100, lr=1e-3, noise_std=0.1):
    """Train a small MLP to denoise 2D points.

    The model is a 2 -> 100 -> 2 perceptron with a ReLU. At every step
    the input batch is corrupted with additive Gaussian noise and the
    network is trained with Adam and an MSE loss to predict the clean
    points from the noisy ones.

    Args:
        data: tensor of training points; the model's first layer
            expects a last dimension of size 2.
        nb_epochs: number of passes over `data`.
        batch_size: number of points per optimization step.
        lr: Adam learning rate.
        noise_std: standard deviation of the Gaussian corruption.

    Returns:
        The trained `nn.Sequential` model.

    The batches are consecutive chunks of `data` in their original
    order; the samples are not reshuffled between epochs. Every 100
    epochs, the epoch number and the loss summed over the batches of
    that epoch are printed.
    """
    model = nn.Sequential(nn.Linear(2, 100), nn.ReLU(), nn.Linear(100, 2))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for e in range(nb_epochs):
        acc_loss = 0.0
        # `batch` rather than `input`, which would shadow the builtin
        for batch in data.split(batch_size):
            noise = torch.empty_like(batch).normal_(0, noise_std)
            output = model(batch + noise)
            # The target is the clean batch, not the corrupted one
            loss = criterion(output, batch)
            acc_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if (e + 1) % 100 == 0:
            print(e + 1, acc_loss)

    return model


######################################################################


def save_image(data_name, model, data):
    """Plot the displacement field of `model` around `data` and save it.

    A regular 30 x 30 grid over [-1.5, 1.5]^2 is built, and an arrow
    going from x to model(x) is drawn at every grid point x that lies
    close enough to the data. The training points are drawn on top as
    small blue dots. The figure is written to the file
    "denoising_field_<data_name>.pdf" and then closed.

    Args:
        data_name: string inserted in the output file name.
        model: callable mapping an `n x 2` tensor to an `n x 2` tensor.
        data: tensor of 2D points, viewed as `1 x n x 2` for the
            distance computation.
    """
    a = torch.linspace(-1.5, 1.5, 30)
    x = a.view(1, -1, 1).expand(a.size(0), a.size(0), 1)
    y = a.view(-1, 1, 1).expand(a.size(0), a.size(0), 1)
    grid = torch.cat((y, x), 2).view(-1, 2)

    # Take the origins of the arrows on the part of the grid closer than
    # sqrt(0.1) to the data points. `dist` is the squared distance from
    # each grid point to its nearest data point, hence the 0.1 threshold.
    dist = (grid.view(-1, 1, 2) - data.view(1, -1, 2)).pow(2).sum(2).min(1)[0]
    origins = grid[dist < 0.1]

    # No gradient is needed for plotting
    with torch.no_grad():
        field = model(origins) - origins

    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1)

        ax.axis("off")
        ax.set_xlim(-1.6, 1.6)
        ax.set_ylim(-1.6, 1.6)
        ax.set_aspect(1)

        ax.quiver(
            origins[:, 0].numpy(),
            origins[:, 1].numpy(),
            field[:, 0].numpy(),
            field[:, 1].numpy(),
            units="xy",
            scale=1,
            width=3e-3,
            headwidth=25,
            headlength=25,
        )

        ax.scatter(data[:, 0].numpy(), data[:, 1].numpy(), s=1, color="tab:blue")

        filename = f"denoising_field_{data_name}.pdf"
        print(f"Saving {filename}")
        fig.savefig(filename, bbox_inches="tight")
    finally:
        # pyplot keeps a reference to every figure it creates; release
        # it so that repeated calls do not accumulate open figures.
        plt.close(fig)


######################################################################

# For each toy distribution: sample 1000 points, center them, train the
# denoising model and save the picture of the resulting field.
for data_source in (data_rectangle, data_zigzag, data_spiral, data_penta):
    data, data_name = data_source(1000)
    data = data - data.mean(dim=0, keepdim=True)
    model = train_model(data)
    save_image(data_name, model, data)
