):
# some paranoia
imt = imt.clone()
- imt[:, 0] = imt[:, 0] * (1 - imt[:1])
+ imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
logits = model(imt[:, 0] * 2 + imt[:, 1])
dist = torch.distributions.categorical.Categorical(logits=logits)
- result = (1 - masks) * imt[:, 0] + masks * dist.sample()
+ result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
record.append(result)
return torch.cat(record)