from torch import nn
from torch.nn import functional as F
+import mygpt
from mygpt import BracketedSequence
######################################################################
# Module that holds a collection of models together with two generation
# settings. Only the constructor is visible in this hunk.
class Gang(nn.Module):
# models: iterable of sub-models to hold.
# nb_models_for_generation: stored as-is on the instance.
# mode: stored as-is on the instance; defaults to "groupthink".
def __init__(self, models, nb_models_for_generation, mode="groupthink"):
super().__init__()
# The patch replaces the plain attribute with an nn.ModuleList so that the
# contained models are registered as sub-modules of this nn.Module
# (a bare Python list attribute is not registered).
- self.models = models
+ self.models = nn.ModuleList(models)
self.nb_models_for_generation = nb_models_for_generation
self.mode = mode
nb, self.train_w_quizzes.size(1), device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
- seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
-
- # bracketing of the temperature to get the target logproba
+ ar_mask_prompt = torch.zeros(c_quizzes.size(), device=self.device)
+ ar_mask_prompt[:, : ar_mask_prompt.size(1) // 2 + 1] = 1
+ ar_mask_solve = 1 - ar_mask_prompt
+ seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
+ warnings.warn("noise injection", RuntimeWarning)
temperature = 1
- d_temperature = 1 / 3
+ noise_std = torch.rand(1).item()
+ self.logger(f"{noise_std=}")
+ mygpt.set_noise_injection(model_for_generation, noise_std)
- while True:
- seq_logproba[...] = 0
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_prompt,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=False,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=ar_mask,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=False,
- # progress_bar_desc="sampling c_quizzes",
- device=self.device,
- )
+ ave_seq_logproba = seq_logproba.mean()
+
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
- ave_seq_logproba = seq_logproba.mean()
-
- # If we do not have target logprobs, get out now
- if min_ave_seq_logproba is None:
- break
-
- # Oh man that's ugly
- if ave_seq_logproba < min_ave_seq_logproba:
- if d_temperature > 0:
- d_temperature *= -1 / 3
- temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
- if d_temperature < 0:
- d_temperature *= -1 / 3
- temperature += d_temperature
- else:
- break
-
- self.logger(f"changing temperature to {temperature}")
+ mygpt.set_noise_injection(model_for_generation, 0.0)
return c_quizzes, seq_logproba.mean()