#!/usr/bin/env python
-import math
+import math, sys
import torch, torchvision
from torch import nn
from torch.nn import functional as F
return False
-def scene2tensor(xh, yh, scene, size=64):
+def scene2tensor(xh, yh, scene, size):
    width, height = size, size
    pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
    data = pixel_map.numpy()
    return scene
-def sequence(nb_steps=10, all_frames=False):
+def generate_sequence(nb_steps=10, all_frames=False, size=64):
    delta = 0.1
    effects = [
        (False, 0, 0),
scene = random_scene()
xh, yh = tuple(x.item() for x in torch.rand(2))
- frames.append(scene2tensor(xh, yh, scene))
+ frames.append(scene2tensor(xh, yh, scene, size=size))
actions = torch.randint(len(effects), (nb_steps,))
change = False
xh, yh = x, y
if all_frames:
- frames.append(scene2tensor(xh, yh, scene))
+ frames.append(scene2tensor(xh, yh, scene, size=size))
if not all_frames:
- frames.append(scene2tensor(xh, yh, scene))
+ frames.append(scene2tensor(xh, yh, scene, size=size))
if change:
break
)
-def train_encoder(input, device=torch.device("cpu")):
- class SomeLeNet(nn.Module):
- def __init__(self):
- super().__init__()
- self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
- self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
- self.fc1 = nn.Linear(256, 200)
- self.fc2 = nn.Linear(200, 10)
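+# Input normalization layer: whitens its input with a mean and log-variance kept as parameters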
+class Normalizer(nn.Module):
+    def __init__(self, mu, std):
+        super().__init__()
+        self.mu = nn.Parameter(mu)
+        self.log_var = nn.Parameter(2 * torch.log(std))
- def forward(self, x):
- x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3))
- x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
- x = x.view(x.size(0), -1)
- x = F.relu(self.fc1(x))
- x = self.fc2(x)
- return x
+    def forward(self, x):
+        return (x - self.mu) / torch.exp(self.log_var / 2.0)
- ######################################################################
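+# Sign function with a straight-through estimator: the forward pass outputs hard +/-1
+# values, the backward pass uses the gradient of tanh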
+class SignSTE(nn.Module):
+    def __init__(self):
+        super().__init__()
- model = SomeLeNet()
-
- nb_parameters = sum(p.numel() for p in model.parameters())
+    def forward(self, x):
+        # torch.sign() maps to {-1, 0, +1}; here we want a strict +/-1 output
+        s = (x >= 0).float() * 2 - 1
+        if self.training:
+            # straight-through estimator: the value is the hard sign s, while the
+            # gradient is that of tanh(x), thanks to s + u - u.detach()
+            u = torch.tanh(x)
+            return s + u - u.detach()
+        else:
+            return s
+
+
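+# Trains a convolutional auto-encoder whose bottleneck quantizes each
+# block_size x block_size patch of the input into nb_bits_per_block +/-1 bits,
+# and returns the trained encoder and decoder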
+def train_encoder(
+    train_input,
+    dim_hidden=64,
+    block_size=16,
+    nb_bits_per_block=10,
+    lr_start=1e-3,
+    lr_end=1e-5,
+    nb_epochs=50,
+    batch_size=25,
+    device=torch.device("cpu"),
+):
+    mu, std = train_input.mean(), train_input.std()
- print(f"nb_parameters {nb_parameters}")
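+    # Encoder: a stack of 5x5 convolutions at full resolution, then a convolution with
+    # kernel and stride block_size that maps each patch to nb_bits_per_block sign bits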
+    encoder = nn.Sequential(
+        Normalizer(mu, std),
+        nn.Conv2d(3, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(
+            dim_hidden,
+            nb_bits_per_block,
+            kernel_size=block_size,
+            stride=block_size,
+            padding=0,
+        ),
+        SignSTE(),
+    )
- optimizer = torch.optim.SGD(model.parameters(), lr=lr)
- criterion = nn.CrossEntropyLoss()
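+    # Decoder: a transposed convolution expands each bit vector back into a
+    # block_size x block_size patch, followed by 5x5 convolutions to reconstruct the RGB frame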
+    decoder = nn.Sequential(
+        nn.ConvTranspose2d(
+            nb_bits_per_block,
+            dim_hidden,
+            kernel_size=block_size,
+            stride=block_size,
+            padding=0,
+        ),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, dim_hidden, kernel_size=5, stride=1, padding=2),
+        nn.ReLU(),
+        nn.Conv2d(dim_hidden, 3, kernel_size=5, stride=1, padding=2),
+    )
- model.to(device)
- criterion.to(device)
+    model = nn.Sequential(encoder, decoder)
- train_input, train_targets = train_input.to(device), train_targets.to(device)
- test_input, test_targets = test_input.to(device), test_targets.to(device)
+    nb_parameters = sum(p.numel() for p in model.parameters())
- mu, std = train_input.mean(), train_input.std()
- train_input.sub_(mu).div_(std)
- test_input.sub_(mu).div_(std)
+    print(f"nb_parameters {nb_parameters}")
- start_time = time.perf_counter()
+    model.to(device)
+    # move the training data as well, so that a non-CPU device works
+    train_input = train_input.to(device)
    for k in range(nb_epochs):
- acc_loss = 0.0
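+        # Log-linear (geometric) interpolation of the learning rate from lr_start to
+        # lr_end over the epochs; a fresh Adam optimizer is created at every epoch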
+        lr = math.exp(math.log(lr_start) + math.log(lr_end / lr_start) / (nb_epochs - 1) * k)
+        print(f"lr {lr}")
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+        acc_loss, nb_samples = 0.0, 0
- for input, targets in zip(
- train_input.split(batch_size), train_targets.split(batch_size)
- ):
+        for input in train_input.split(batch_size):
            output = model(input)
- loss = criterion(output, targets)
- acc_loss += loss.item()
+            loss = F.mse_loss(output, input)
+            acc_loss += loss.item() * input.size(0)
+            nb_samples += input.size(0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
- nb_test_errors = 0
- for input, targets in zip(
- test_input.split(batch_size), test_targets.split(batch_size)
- ):
- wta = model(input).argmax(1)
- nb_test_errors += (wta != targets).long().sum()
- test_error = nb_test_errors / test_input.size(0)
- duration = time.perf_counter() - start_time
+        print(f"loss {k} {acc_loss/nb_samples}")
+        sys.stdout.flush()
- print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")
+    return encoder, decoder
######################################################################
import time
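+# Generate nb random sequences and collect their frames as training data for the auto-encoder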
all_frames = []
- nb = 1000
+ nb = 25000
start_time = time.perf_counter()
for n in range(nb):
-    frames, actions = sequence(nb_steps=31)
+    frames, actions = generate_sequence(nb_steps=31)
    all_frames += frames
end_time = time.perf_counter()
print(f"{nb / (end_time - start_time):.02f} samples per second")
input = torch.cat(all_frames, 0)
+ encoder, decoder = train_encoder(input)
# x = patchify(input, 8)
# y = x.reshape(x.size(0), -1)
# results = results.reshape(x.size())
# results = patchify(results, 8, input.size())
- print(f"{input.size()=} {results.size()=}")
+# encode the frames and reconstruct them from the quantized code; no gradients needed here
+with torch.no_grad():
+    z = encoder(input)
+    results = decoder(z)
+
+print(f"{input.size()=} {z.size()=} {results.size()=}")
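+# Save an 8x8 grid of original frames and of their reconstructions from the quantized code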
torchvision.utils.save_image(input[:64], "orig.png", nrow=8)
+
torchvision.utils.save_image(results[:64], "qtiz.png", nrow=8)
- # frames, actions = sequence(nb_steps=31, all_frames=True)
+ # frames, actions = generate_sequence(nb_steps=31, all_frames=True)
# frames = torch.cat(frames, 0)
# torchvision.utils.save_image(frames, "seq.png", nrow=8)