)
#!!!!!!!!!!!!!!!!!!!!
# TEMP DEBUG: cross-check the log-probabilities accumulated during generation
# (seq_logproba) against a fresh per-model evaluation via models_logprobas.
# NOTE(review): unresolved diff markers were left in this section; resolved to
# the "+" side of the hunk (per-model call on solved_c_quizzes[:, m, :]).
# Indentation reconstructed — the diff residue had flattened it; confirm the
# block's nesting level against the enclosing function.
for m in range(seq_logproba.size(1)):
    # Evaluate model m alone on its own solved quizzes so l[s, 0] lines up
    # with seq_logproba[s, m].
    l = quiz_machine.models_logprobas(
        [models[m]],
        solved_c_quizzes[:, m, :],
        ("A", "f_A", "B", "f_B"),
        (0, 0, 0, 1),
        (0, 0, 0, 0),
    )
    for s in range(seq_logproba.size(0)):
        print("DEBUG", seq_logproba[s, m].item(), l[s, 0].item())
exit(0)
#!!!!!!!!!!!!!!!!!!!!!!!!!
# NOTE(review): unresolved diff markers were left here; resolved to the "+"
# side of the hunk. The fix matters: `logits` are unnormalized scores, so the
# accumulated quantity must use log_softmax to be a true log-probability.
all_n = torch.arange(t_next.size(0))
# Accumulate the log-probability of the sampled token, only at positions
# selected by the autoregressive mask.
acc_seq_logproba += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
# Write the sampled token into masked positions; keep the original token
# elsewhere (assumes ar_mask is 0/1-valued — confirm upstream).
input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]