+# Gets a pair (x, t) and appends t (scalar or 1d tensor) to x as an
+# additional dimension / channel
+
+class TimeAppender(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, u):
+ x, t = u
+ if not torch.is_tensor(t):
+ t = x.new_full((x.size(0),), t)
+ t = t.view((-1,) + (1,) * (x.dim() - 1)).expand_as(x[:,:1])
+ return torch.cat((x, t), 1)
+