            m.bias.zero_()
            m.weight.fill_(1.0)
+    # The input x, x permuted by sigma, and the permuted sequence
+    # shifted right by one position with a padding token in front,
+    # which is what the model actually reads:
+    #
+    # x[   0    ], x[   1    ], ..., x[   T-2    ], x[   T-1    ]
+    # x[sigma[0]], x[sigma[1]], ..., x[sigma[T-2]], x[sigma[T-1]]
+    # x[   -1   ], x[sigma[0]], ..., x[sigma[T-3]], x[sigma[T-2]]
+    #
+    # The readout y is computed in permuted order, then scattered back
+    # so that position t holds the prediction for x[t]:
+    #
+    # y[sigma[0]], y[sigma[1]], ..., y[sigma[T-2]], y[sigma[T-1]]
+    # y[   0    ], y[   1    ], ..., y[   T-2    ], y[   T-1    ]
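+    #
+    # A toy round trip through gather / pad / scatter (illustrative
+    # values, not from this codebase):
+    #
+    #   x     = torch.tensor([10, 20, 30, 40])
+    #   sigma = torch.tensor([2, 0, 3, 1])
+    #   xp    = x.gather(0, sigma)               # [30, 10, 40, 20]
+    #   F.pad(xp, (1, -1))                       # [ 0, 30, 10, 40]
+    #   xp.new_zeros(4).scatter(0, sigma, xp)    # [10, 20, 30, 40]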
+
    def forward(self, bs, sigma=None):
        if sigma is not None:
+            # permute: x[n,t] = x[n,sigma[n,t]]
            bs.x = bs.x.gather(dim=1, index=sigma)
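        # shift right by one token: prepend a 0 (the padding value) and
        # drop the last element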
        bs = BracketedSequence(F.pad(bs.x, (1, -1)), bs.first, bs.nb)
        bs = self.embedding(bs)
        bs = self.readout(bs)
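        # temperature-scale only the logits computed in this call,
        # i.e. the bracketed slice [first, first + nb)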
        bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
        if sigma is not None:
-            bs.x.scatter_(
-                dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x.clone()
-            )
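+            # un-permute into a fresh tensor: y[n,sigma[n,t]] = bs.x[n,t]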
+            y = bs.x.new_zeros(bs.x.size())
+            y.scatter_(dim=1, index=sigma[:, :, None].expand_as(bs.x), src=bs.x)
+            bs.x = y
        return bs

    def encode(self, bs):
    if input.size(0) == 0:
        return
-    to_generate = (ar_mask.sum(0) > 0).nonzero()
-
-    if to_generate.min() > 0:
-        model(
-            BracketedSequence(input, 0, to_generate.min())
-        )  # Needed to initialize the model's cache
-    for s in range(to_generate.min(), to_generate.max() + 1):
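+    # visit positions in the order given by sigma: step s generates the
+    # token for original position sigma[:, s]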
+    for s in range(input.size(1)):
        output = model(BracketedSequence(input, s, 1), sigma).x
-
-        logits = output[:, s]
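+        # forward() scatters its predictions back to their original
+        # positions, so this step's logits sit at u = sigma[:, s]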
+        all_n = torch.arange(input.size(0), device=input.device)
+        u = sigma[:, s]
+        logits = output[all_n, u]

        if deterministic_synthesis:
            t_next = logits.argmax(-1)
        else:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            t_next = dist.sample()
-        all_n = torch.arange(t_next.size(0))
-
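
        # accumulate the logit of the selected token for each sequence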
        seq_logproba += logits[all_n, t_next]

-        input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
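+        # overwrite only the positions where ar_mask is 1; the others
+        # (e.g. the prompt) keep their current values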
+        input[all_n, u] = (
+            ar_mask[all_n, u] * t_next + (1 - ar_mask[all_n, u]) * input[all_n, u]
+        )
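+    # Note: how sigma is built is not shown in this hunk; a plausible
+    # choice (an assumption, not from this diff) is a fresh per-row
+    # permutation, e.g.
+    #   sigma = torch.rand(input.size(0), input.size(1),
+    #                      device=input.device).argsort(dim=1)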
######################################################################