parser.add_argument("--temperature_cold", type=float, default=1)
-parser.add_argument("--prompt_noise", type=float, default=0.05)
+parser.add_argument("--prompt_noise", type=float, default=0.0)
parser.add_argument("--dirty_debug", action="store_true", default=False)
problem=problem,
batch_size=args.inference_batch_size,
result_dir=args.result_dir,
- prompt_noise=args.prompt_noise,
logger=log_string,
device=main_device,
)
nb_diffusion_iterations = 25
-def degrade_input(input, mask_generate, nb_iterations):
+def degrade_input_to_generate(input, mask_generate, nb_iterations):
noise = torch.randint(
quiz_machine.problem.nb_colors, input.size(), device=input.device
)
return result
-def targets_and_prediction(model, input, mask_generate):
+def targets_and_prediction(model, input, mask_generate, prompt_noise=0.0):
d = deterministic(mask_generate)
+
probs_iterations = 0.1 ** torch.linspace(
0, 1, args.nb_diffusion_iterations, device=input.device
)
+
probs_iterations = probs_iterations[None, :] / probs_iterations.sum()
probs_iterations = probs_iterations.expand(input.size(0), -1)
dist = torch.distributions.categorical.Categorical(probs=probs_iterations)
- N0 = dist.sample()
- N1 = N0 + 1
- N0 = (1 - d) * N0
- N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
- targets, input = degrade_input(input, mask_generate, (0 * N1, N1))
+ # N0 = dist.sample()
+ # N1 = N0 + 1
+ # N0 = (1 - d) * N0
+ # N1 = (1 - d) * N1 + d * args.nb_diffusion_iterations
+
+ N0 = input.new_zeros(input.size(0))
+ N1 = dist.sample() + 1
+
+ targets, input = degrade_input_to_generate(input, mask_generate, (N0, N1))
+
+ if prompt_noise > 0:
+ mask_prompt_noise = (
+            torch.rand(input.size(), device=input.device) <= prompt_noise
+ ).long()
+ noise = torch.randint(
+ quiz_machine.problem.nb_colors, input.size(), device=input.device
+ )
+ noisy_input = (1 - mask_prompt_noise) * input + mask_prompt_noise * noise
+ input = (1 - mask_generate) * noisy_input + mask_generate * input
input_with_mask = NTC_channel_cat(input, mask_generate)
logits = model(input_with_mask)
######################################################################
-def one_ae_epoch(
- model, other_models, quiz_machine, n_epoch, c_quizzes, local_device=main_device
-):
+def one_ae_epoch(model, quiz_machine, n_epoch, c_quizzes, local_device=main_device):
model.train().to(local_device)
optimizer_to(model.optimizer, local_device)
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
- targets, logits = targets_and_prediction(model, input, mask_generate)
+ targets, logits = targets_and_prediction(
+ model, input, mask_generate, prompt_noise=args.prompt_noise
+ )
loss = NTC_masked_cross_entropy(logits, targets, mask_loss)
acc_train_loss += loss.item() * input.size(0)
# --------------------------------------------------------------------
- # one_ae_epoch(models[0], models, quiz_machine, n_epoch, main_device)
+ # one_ae_epoch(models[0], quiz_machine, n_epoch, main_device)
# exit(0)
log_string(f"{time_train=} {time_c_quizzes=}")
t = threading.Thread(
target=one_ae_epoch,
daemon=True,
- args=(model, models, quiz_machine, n_epoch, c_quizzes, gpu),
+ args=(model, quiz_machine, n_epoch, c_quizzes, gpu),
)
threads.append(t)
problem,
batch_size,
result_dir,
- prompt_noise,
logger,
device=torch.device("cpu"),
):
self.logger = logger
self.prompt_len = None
self.answer_len = None
- self.prompt_noise = prompt_noise
# quad_order, quad_generate, quad_noise, quad_loss
self.train_structures = [
quad_order, quad_generate, quad_noise, quad_loss = s
i = order_ids == j
quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order)
- if self.prompt_noise > 0.0:
- quizzes[i] = self.problem.inject_noise(
- quizzes[i],
- self.prompt_noise,
- quad_order=quad_order,
- quad_noise=quad_noise,
- )
quiz_mask_generate[i] = self.make_quiz_mask(
quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
)
device=device,
)
- # if self.prompt_noise > 0.0 and quad_noise is not None:
- # c_quizzes = self.problem.inject_noise(
- # c_quizzes, self.prompt_noise, quad_order=quad_order, quad_noise=quad_noise
- # )
-
with torch.autograd.no_grad():
t = model.training
model.eval()