######################################################################
-lt_noisy = lambda s, logits: logits / args.temperature_hot
-lt_clean = lambda s, logits: logits / args.temperature_cold
+
+def model_transformer_hot(model):
+ model.temperature = args.temperature_hot
+
+
+def model_transformer_cold(model):
+ model.temperature = args.temperature_cold
+
c_quizzes_procedure = [
- (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), lt_noisy),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), lt_clean),
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), lt_clean),
+ (("f_B", "f_A", "A", "B"), (1, 0, 0, 0), model_transformer_hot),
+ (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), model_transformer_cold),
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), model_transformer_cold),
]
c_quizzes_procedure_ = [
- (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), lt_noisy),
- (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), lt_clean),
+ (("A", "f_A", "B", "f_B"), (1, 1, 0, 0), model_transformer_hot),
+ (("A", "f_A", "B", "f_B"), (0, 0, 1, 1), model_transformer_cold),
]
return x
-def set_noise_injection(model, noise_std, identifier=None):
- for m in model.modules():
- if isinstance(m, NoiseInjector):
- if identifier is None or identifier == m.identifier:
- m.noise_std = noise_std
-
-
-def reset_noise_injection(model):
- for m in model.modules():
- if isinstance(m, NoiseInjector):
- m.noise_std = 0.0
-
-
##############################
assert dim_model % nb_heads == 0
+ self.temperature = 1.0
+
self.embedding = nn.Sequential(
CacheWrapper(nn.Embedding(vocabulary_size, dim_model), nn.Dropout(dropout)),
AddPositionalEncoding(len_max),
bs = self.embedding(bs)
bs = self.trunk(bs)
bs = self.readout(bs)
+ bs.x[:, bs.first : bs.first + bs.nb] /= self.temperature
return bs
def encode(self, bs):
bs = self.readout(bs)
return bs
+ def reset_transformations(self):
+ self.temperature = 1.0
+ for m in self.modules():
+ if isinstance(m, NoiseInjector):
+ m.noise_std = 0.0
+
+ def set_noise_injection(self, noise_std, identifier=None):
+ for m in model.modules():
+ if isinstance(m, NoiseInjector):
+ if identifier is None or identifier == m.identifier:
+ m.noise_std = noise_std
+
def record_attention(self, v=True):
for m in self.modules():
if isinstance(m, QKVAttention):
c_quizzes = self.problem.reconfigure(c_quizzes, s)
pred_s = s
- mt(model_for_generation)
+ if mt is not None:
+ mt(model_for_generation)
self.autoregression(
model=model_for_generation,
seq_logproba=seq_logproba,
)
- model_for_generation.reset_transformation()
+ model_for_generation.reset_transformations()
if to_recycle is not None and to_recycle.size(0) > 0:
to_recycle = self.problem.reconfigure(to_recycle, s)