    if invert_size is None:
        return (
            x.reshape(
                x.size(0),  # 0
                x.size(1),  # 1
                factor,  # 2
                x.size(2) // factor,  # 3
                factor,  # 4
                x.size(3) // factor,  # 5
            )
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(-1, x.size(1), x.size(2) // factor, x.size(3) // factor)
        )
    else:
        return (
            x.reshape(
                invert_size[0],  # 0
                factor,  # 1
                factor,  # 2
                invert_size[1],  # 3
                invert_size[2] // factor,  # 4
                invert_size[3] // factor,  # 5
            )
            .permute(0, 3, 1, 4, 2, 5)
            .reshape(invert_size)
        )
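

######################################################################

# A minimal round-trip sanity check for the patch reshuffling above,
# kept commented out like the experiments in __main__ below. It assumes
# the enclosing function is patchify(x, factor, invert_size=None), as
# the calls at the bottom of the file suggest; the tensor shape is an
# arbitrary illustration.
#
# x = torch.randn(2, 3, 16, 16)
# patches = patchify(x, 2)  # -> (2 * 2 * 2, 3, 8, 8)
# restored = patchify(patches, 2, x.size())  # -> (2, 3, 16, 16)
# assert torch.equal(x, restored)

######################################################################
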
# The training and test tensors and the hyper-parameters lr, nb_epochs,
# and batch_size are taken as explicit arguments so the function is
# self-contained; the default values below are assumed placeholders.
def train_encoder(
    train_input,
    train_targets,
    test_input,
    test_targets,
    lr=1e-1,
    nb_epochs=10,
    batch_size=100,
    device=torch.device("cpu"),
):
    class SomeLeNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
            self.fc1 = nn.Linear(256, 200)
            self.fc2 = nn.Linear(200, 10)

        def forward(self, x):
            x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3))
            x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
            x = x.view(x.size(0), -1)
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return x

    ######################################################################

    model = SomeLeNet()

    nb_parameters = sum(p.numel() for p in model.parameters())

    print(f"nb_parameters {nb_parameters}")

    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    model.to(device)
    criterion.to(device)

    train_input, train_targets = train_input.to(device), train_targets.to(device)
    test_input, test_targets = test_input.to(device), test_targets.to(device)

    # Standardize with the train statistics, and apply the same
    # transformation to the test set
    mu, std = train_input.mean(), train_input.std()
    train_input.sub_(mu).div_(std)
    test_input.sub_(mu).div_(std)

    start_time = time.perf_counter()

    for k in range(nb_epochs):
        acc_loss = 0.0

        for input, targets in zip(
            train_input.split(batch_size), train_targets.split(batch_size)
        ):
            output = model(input)
            loss = criterion(output, targets)
            acc_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        nb_test_errors = 0
        with torch.no_grad():
            for input, targets in zip(
                test_input.split(batch_size), test_targets.split(batch_size)
            ):
                wta = model(input).argmax(1)
                nb_test_errors += (wta != targets).long().sum().item()
        test_error = nb_test_errors / test_input.size(0)
        duration = time.perf_counter() - start_time

        print(f"loss {k} {duration:.02f}s {acc_loss:.02f} {test_error*100:.02f}%")


######################################################################

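# A hedged usage sketch for train_encoder: the MNIST-shaped random
# tensors and the hyper-parameter values are illustrative assumptions,
# not data from the original pipeline.
#
# train_input, train_targets = torch.randn(1000, 1, 28, 28), torch.randint(10, (1000,))
# test_input, test_targets = torch.randn(200, 1, 28, 28), torch.randint(10, (200,))
# train_encoder(train_input, train_targets, test_input, test_targets, nb_epochs=2)

######################################################################
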
if __name__ == "__main__":
    import time
    print(f"{nb / (end_time - start_time):.02f} samples per second")
    input = torch.cat(all_frames, 0)

    # The k-means patch-quantization experiment is kept disabled; the
    # final size print is commented out with it, since results is only
    # defined by this block.
    #
    # x = patchify(input, 8)
    # y = x.reshape(x.size(0), -1)
    # print(f"{x.size()=} {y.size()=}")
    # centroids, t = kmeans(y, 4096)
    # results = centroids[t]
    # results = results.reshape(x.size())
    # results = patchify(results, 8, input.size())
    # print(f"{input.size()=} {results.size()=}")