+
+# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
+# additional dimension / channel
+
+class TimeAppender(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, u):
+ x, t = u
+ if not torch.is_tensor(t):
+ t = x.new_full((x.size(0),), t)
+ t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
+ return torch.cat((x, t), 1)
+
+class ConvNet(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+
+ ks, nc = 5, 64
+
+ self.core = nn.Sequential(
+ TimeAppender(),
+ nn.Conv2d(in_channels + 1, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, nc, ks, padding = ks//2),
+ nn.ReLU(),
+ nn.Conv2d(nc, out_channels, ks, padding = ks//2),
+ )
+
+ def forward(self, u):
+ return self.core(u)
+
+######################################################################
+# Data