projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
0057bf1
)
Update.
master
author
François Fleuret
<francois@fleuret.org>
Mon, 30 Sep 2024 06:28:37 +0000
(08:28 +0200)
committer
François Fleuret
<francois@fleuret.org>
Mon, 30 Sep 2024 06:28:37 +0000
(08:28 +0200)
tasks.py
patch
|
blob
|
history
diff --git
a/tasks.py
b/tasks.py
index
5f3258b
..
9901715
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-665,13
+665,6
@@
class MNIST(Task):
def create_fourier_basis(self):
self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
self.fourier_basis_inverse = self.fourier_basis.inverse()
def create_fourier_basis(self):
self.fourier_basis = generate_2d_fourier_basis(T=28).flatten(1)
self.fourier_basis_inverse = self.fourier_basis.inverse()
-
- torchvision.utils.save_image(
- 1 - self.fourier_basis.reshape(-1, 1, 28, 28) / self.fourier_basis.std(),
- "fourier.png",
- nrow=28,
- )
-
y = self.train_input.float() @ self.fourier_basis.t()
self.fourier_range = 4
self.fourier_mu = y.mean(dim=0, keepdim=True)
y = self.train_input.float() @ self.fourier_basis.t()
self.fourier_range = 4
self.fourier_mu = y.mean(dim=0, keepdim=True)
@@
-718,24
+711,6
@@
class MNIST(Task):
self.train_input = self.fourier_encode(self.train_input)
self.test_input = self.fourier_encode(self.test_input)
self.train_input = self.fourier_encode(self.train_input)
self.test_input = self.fourier_encode(self.test_input)
- torchvision.utils.save_image(
- 1
- - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
- / 256,
- "check-train.png",
- nrow=16,
- )
- torchvision.utils.save_image(
- 1
- - self.fourier_decode(self.test_input[:256]).reshape(-1, 1, 28, 28)
- / 256,
- "check-test.png",
- nrow=16,
- )
- # exit(0)
-
- print(f"AFTER {self.train_input.size()=} {self.test_input.size()=}")
-
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
def batches(self, split="train", nb_to_use=-1, desc=None):
assert split in {"train", "test"}
input = self.train_input if split == "train" else self.test_input
@@
-754,6
+729,26
@@
class MNIST(Task):
def produce_results(
self, n_epoch, model, result_dir, logger, deterministic_synthesis
):
def produce_results(
self, n_epoch, model, result_dir, logger, deterministic_synthesis
):
+ if n_epoch == 0:
+ image_name = os.path.join(result_dir, "fourier.png")
+ torchvision.utils.save_image(
+ 0.5
+ - 0.5
+ * self.fourier_basis.reshape(-1, 1, 28, 28)
+ / self.fourier_basis.std(),
+ image_name,
+ nrow=28,
+ )
+
+ image_name = os.path.join(result_dir, "check-train.png")
+ torchvision.utils.save_image(
+ 1
+ - self.fourier_decode(self.train_input[:256]).reshape(-1, 1, 28, 28)
+ / 256,
+ image_name,
+ nrow=16,
+ )
+
results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
ar_mask = torch.full_like(results, 1)
masked_inplace_autoregression(
results = torch.empty(64, 28 * 28, device=self.device, dtype=torch.int64)
ar_mask = torch.full_like(results, 1)
masked_inplace_autoregression(