######################################################################
- def make_mask_hints(mask_generate, nb_hints):
+ def make_mask_hints(self, mask_generate, nb_hints):
if nb_hints is None:
- mask_hints = None
+ mask_hints = torch.zeros(
+ mask_generate.size(),
+ device=mask_generate.device,
+ dtype=mask_generate.dtype,
+ )
else:
u = (
torch.rand(mask_generate.size(), device=mask_generate.device)
t = dist.sample() + 1
- x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise
+ x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
x_t = self.sample_x_t_given_x_0(x_0, t)
x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
mask_hints = self.make_mask_hints(mask_generate, nb_hints)
- x_T_with_hints = mask_hints * x_0 + (1 - mask_hint) * noise
- x_t = self.sample_x_t_given_x_0(x_0, t)
- x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * x_t
+ x_T_with_hints = mask_hints * x_0 + (1 - mask_hints) * noise
+ x_t = single_iteration * x_T_with_hints + (1 - single_iteration) * noise
x_t = (1 - mask_generate) * x_0 + mask_generate * x_t
changed = True