# pass
-warnings.warn("*********** novel procedure!!! **********", RuntimeWarning)
-
c_quizzes_procedure = [
- # (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
- # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
- # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
(("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 0), model_transformer_cold),
- (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
- # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
- # (("f_B", "f_A", "A", "B"), (0, 0, 1, 1), model_transformer_cold),
- # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+ (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), model_transformer_cold),
]
######################################################################
struct = ("A", "f_A", "B", "f_B")
c_quizzes = quiz_machine.problem.create_empty_quizzes(nb, struct=struct)
- ar_mask = quiz_machine.make_ar_mask(c_quizzes, struct, (1, 1, 1, 1))
+ ar_mask = quiz_machine.make_quiz_mask(c_quizzes, struct, (1, 1, 1, 1))
i = F.one_hot(
torch.randint(args.nb_gpts, (c_quizzes.size(0),)),
self.prompt_noise = prompt_noise
# struct, mask_generate, mask_noise, mask_loss
- self.understood_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ self.train_structures = [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+ (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+ (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+ (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
(("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
+ # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+ # (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 0)),
+ # (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 0, 1)),
+ # (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
]
- self.test_structures = [
- self.understood_structures[0],
- self.understood_structures[2],
- self.understood_structures[4],
- ]
+ self.test_structures = self.train_structures
self.LOCK_C_QUIZZES = threading.Lock()
self.train_c_quizzes = []
quizzes, from_w = quizzes[i], from_w[i]
self.randomize_configuations_inplace(
- quizzes, structs=[s for s, _, _, _ in self.understood_structures]
+ quizzes, structs=[s for s, _, _, _ in self.train_structures]
)
quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
if self.prompt_noise > 0.0:
- for struct, _, mask_noise, mask_loss in self.understood_structures:
+ for struct, _, mask_noise, mask_loss in self.train_structures:
i = self.problem.indices_select(quizzes=quizzes, struct=struct)
if i.any():
quizzes[i] = self.problem.inject_noise(
######################################################################
def make_ar_mask(self, quizzes, struct, mask):
- assert struct in [s for s, _, _, _ in self.understood_structures]
+ assert struct in [s for s, _, _, _ in self.train_structures]
return self.problem.make_ar_mask(quizzes, struct=struct, mask=mask)
######################################################################
models_for_validation,
c_quizzes,
struct,
- mask_value,
+ mask_loss,
mask_noise=None,
device=None,
):
seq_logproba.split(self.batch_size),
):
input = input.to(device)
- ar_mask = self.make_ar_mask(input, struct=struct, mask=mask_value)
+ quiz_mask_loss = self.make_ar_mask(
+ input, struct=struct, mask=mask_loss
+ )
output = model(mygpt.BracketedSequence(input)).x
l[:, model.id] = (
-F.cross_entropy(
output.transpose(1, 2), input, reduction="none"
)
- * ar_mask
+ * quiz_mask_loss
).sum(dim=1)
model.train(t)