##############################
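+# Adds i.i.d. zero-mean Gaussian noise of standard deviation noise_std to its
+# input. With the default noise_std == 0 it is a no-op, so the model behaves
+# normally unless set_noise_injection() enables the noise.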
+class NoiseInjector(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.noise_std = 0.0
+
+ def forward(self, x):
+ if self.noise_std > 0:
+ x = x + torch.randn_like(x) * self.noise_std
+ return x
+
+
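+# Sets noise_std on every NoiseInjector module of the model.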
+def set_noise_injection(model, noise_std):
+ for m in model.modules():
+ if isinstance(m, NoiseInjector):
+ m.noise_std = noise_std
+
+
+##############################
+
+
class MyGPT(nn.Module):
def __init__(
self,
for b in range(nb_blocks):
trunk_blocks += [
WithResidual(
- CacheWrapper(nn.LayerNorm((dim_model,))),
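+ # a NoiseInjector after each LayerNorm so activation noise can be
+ # injected into the attention and MLP branches during generation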
+ CacheWrapper(
+ nn.LayerNorm((dim_model,)),
+ NoiseInjector(),
+ ),
QKVAttention(
dim_in=dim_model,
dim_qk=dim_keys,
WithResidual(
CacheWrapper(
nn.LayerNorm((dim_model,)),
+ NoiseInjector(),
nn.Linear(in_features=dim_model, out_features=dim_hidden),
nn.ReLU(),
nn.Linear(in_features=dim_hidden, out_features=dim_model),
from torch import nn
from torch.nn import functional as F
+import mygpt
from mygpt import BracketedSequence
######################################################################
class Gang(nn.Module):
def __init__(self, models, nb_models_for_generation, mode="groupthink"):
super().__init__()
- self.models = models
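+ # nn.ModuleList registers the sub-models so they follow .to(device)
+ # and are included in the state dict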
+ self.models = nn.ModuleList(models)
self.nb_models_for_generation = nb_models_for_generation
self.mode = mode
ar_mask_solve = 1 - ar_mask_prompt
seq_logproba = torch.empty(ar_mask_prompt.size(0), device=self.device)
- # bracketing of the temperature to get the target logproba if
- # min_ave_seq_logproba is not None
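+ # Instead of bracketing the temperature, perturb the generator with
+ # activation noise: draw a noise_std uniformly in [0, 1) for this pass
+ # and reset it to 0 once generation is done.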
+ warnings.warn("noise injection", RuntimeWarning)
+ temperature = 1
+ noise_std = torch.rand(1).item()
+ self.logger(f"{noise_std=}")
+ mygpt.set_noise_injection(model_for_generation, noise_std)
- temperature = 2
- d_temperature = 1 / 3
-
- while True:
- seq_logproba[...] = 0
-
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=ar_mask_prompt,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=False,
- # progress_bar_desc="sampling c_quizzes",
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_prompt,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=False,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
- ave_seq_logproba = seq_logproba.mean()
+ ave_seq_logproba = seq_logproba.mean()
- masked_inplace_autoregression(
- model=model_for_generation,
- batch_size=self.batch_size,
- input=c_quizzes,
- ar_mask=ar_mask_solve,
- seq_logproba=seq_logproba,
- temperature=temperature,
- deterministic_synthesis=True,
- # progress_bar_desc="sampling c_quizzes",
- device=self.device,
- )
+ masked_inplace_autoregression(
+ model=model_for_generation,
+ batch_size=self.batch_size,
+ input=c_quizzes,
+ ar_mask=ar_mask_solve,
+ seq_logproba=seq_logproba,
+ temperature=temperature,
+ deterministic_synthesis=True,
+ # progress_bar_desc="sampling c_quizzes",
+ device=self.device,
+ )
- # If we do not have target logprobs, get out now
- if min_ave_seq_logproba is None:
- break
-
- # Oh man that's ugly
- if ave_seq_logproba < min_ave_seq_logproba:
- if d_temperature > 0:
- d_temperature *= -1 / 3
- temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
- if d_temperature < 0:
- d_temperature *= -1 / 3
- temperature += d_temperature
- else:
- break
-
- self.logger(f"changing temperature to {temperature}")
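+ # turn the noise back off so the model is left unperturbed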
+ mygpt.set_noise_injection(model_for_generation, 0.0)
return c_quizzes, seq_logproba.mean()
def generate_frame_sequences_hard(self, nb):
frame_sequences = []
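+ # simulate enough raw frames so that keeping one frame every
+ # self.speed steps yields self.nb_iterations frames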
+ nb_frames = (self.nb_iterations - 1) * self.speed + 1
result = torch.full(
- (nb * 4, self.nb_iterations * self.speed, self.height, self.width),
+ (nb * 4, nb_frames, self.height, self.width),
self.token_empty,
)
result[n, 0, i + vi, j + vj] = self.token_tail
break
- if torch.rand(1) < 0.75:
- break
+ # if torch.rand(1) < 0.75:
+ break
weight = torch.full((1, 1, 3, 3), 1.0)
# tail->conductor
# conductor->head if 1 or 2 head in the neighborhood, or remains conductor
- for l in range(self.nb_iterations * self.speed - 1):
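+ # keep only the sequences that start with at least one head and whose
+ # head count never decreases across frames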
+ nb_heads = (result[:, 0] == self.token_head).flatten(1).long().sum(dim=1)
+ valid = nb_heads > 0
+
+ for l in range(nb_frames - 1):
nb_head_neighbors = (
F.conv2d(
input=(result[:, l] == self.token_head).float()[:, None, :, :],
+ (1 - mask_1_or_2_heads) * self.token_conductor
)
)
+ prev_nb_heads = nb_heads
+ nb_heads = (
+ (result[:, l + 1] == self.token_head).flatten(1).long().sum(dim=1)
+ )
+ valid = torch.logical_and(valid, (nb_heads >= prev_nb_heads))
+
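+ # drop the sequences rejected by the head-count check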
+ result = result[valid]
result = result[
:, torch.arange(self.nb_iterations, device=result.device) * self.speed