# Written by Francois Fleuret <francois@fleuret.org>
-import math, os, tqdm, warnings
+import math, os, tqdm, warnings, sys
import torch, torchvision
class MNIST(Task):
+ def fourier_encode(self, x):
+ y = torch.linalg.lstsq(
+ self.fourier_basis.to(x.device).t(), x.float().t()
+ ).solution.t()
+ y = ((y / 255).clamp(min=0, max=1) * 255).long()
+ return y
+
+ def fourier_decode(self, y):
+ return y.float() @ self.fourier_basis.to(y.device)
+
def __init__(
- self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ fourier=True,
+ device=torch.device("cpu"),
):
super().__init__()
data_set = torchvision.datasets.MNIST(root="./data", train=False, download=True)
self.test_input = data_set.data[:nb_test_samples].view(-1, 28 * 28).long()
+ self.fourier = fourier
+
+ if fourier:
+ self.create_fourier_basis()
+
+ # print(f"BEFORE {self.train_input.size()=} {self.test_input.size()=}")
+
+ self.train_input = self.fourier_encode(self.train_input[:256])
+ self.test_input = self.fourier_encode(self.test_input[:256])
+
+ torchvision.utils.save_image(
+ 1 - self.fourier_decode(self.train_input).reshape(-1, 1, 28, 28) / 256,
+ "train.png",
+ nrow=16,
+ )
+ torchvision.utils.save_image(
+ 1 - self.fourier_decode(self.test_input).reshape(-1, 1, 28, 28) / 256,
+ "test.png",
+ nrow=16,
+ )
+ # exit(0)
+
+ print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}")
+
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
deterministic_synthesis,
device=self.device,
)
+
+ results = self.fourier_decode(results)
+
image_name = os.path.join(result_dir, f"mnist_result_{n_epoch:04d}.png")
torchvision.utils.save_image(
1 - results.reshape(-1, 1, 28, 28) / 255.0,