X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=e7362b749210dcce6c6b92934bd34744d95b770d;hb=6833683bd343fd687d093d6c47cca8f1909e8b03;hp=de69a755f9510daff3669c252fb14a8bec4b3148;hpb=f3f490def0be8a3ea2b9a0ac60f5bb33c5c45fb5;p=mygptrnn.git

diff --git a/mygpt.py b/mygpt.py
index de69a75..e7362b7 100755
--- a/mygpt.py
+++ b/mygpt.py
@@ -37,7 +37,7 @@ import ffutils
 # 1 for the successive tokens.
 #
 # Modules able to process brackets may implement a cache that is
-# resetted when the input bracket starts at t=0
+# reset when init_cache is True
 
 
 class BracketedSequence:
@@ -482,7 +482,7 @@ class Caterpillar(nn.Module):
         self.attention_dropout = attention_dropout
 
         warnings.warn("flash back", RuntimeWarning)
-        self.proba_flashback = 0.1
+        self.proba_flashback = 1e-2
 
         self.w_G = randw(nb_heads, caterpillar_height, dim_model)
         self.b_G = nn.Parameter(
@@ -603,20 +603,18 @@ class Caterpillar(nn.Module):
             src_time = t - u - t0
             src_head = torch.randint(H, (N, CH, t1 - t0, 1), device=X.device)
 
-            mask_V = (
+            mask = (
                 torch.rand(N, CH, t1 - t0, DV, device=X.device) <= self.proba_flashback
             ).long()
+
             self.rec_V[:, :, t0:t1] = (
-                mask_V * V[n, src_head, src_time, dv]
-                + (1 - mask_V) * self.rec_V[:, :, t0:t1]
+                mask * V[n, src_head, src_time, dv]
+                + (1 - mask) * self.rec_V[:, :, t0:t1]
             )
 
-            mask_K = (
-                torch.rand(N, CH, t1 - t0, DK, device=X.device) <= self.proba_flashback
-            ).long()
             self.rec_K[:, :, t0:t1] = (
-                mask_K * K[n, src_head, src_time, dk]
-                + (1 - mask_K) * self.rec_K[:, :, t0:t1]
+                mask * K[n, src_head, src_time, dk]
+                + (1 - mask) * self.rec_K[:, :, t0:t1]
             )
 
 ######################################################################
@@ -773,7 +771,12 @@ class MyGPT(nn.Module):
     ):
         super().__init__()
 
-        assert attention_layer in {"mha", "dumbrec", "kvrec", "caterpillar"}
+        assert attention_layer in {
+            "mha",
+            "dumbrec",
+            "kvrec",
+            "caterpillar",
+        }, f"Unknown attention operator {attention_layer}."
 
         if attention_layer == "caterpillar":
             assert nb_lines % caterpillar_height == 0
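
Context for the third hunk: the "flashback" mechanism overwrites, with probability self.proba_flashback (lowered here from 0.1 to 1e-2), individual entries of the recurrent state rec_V/rec_K with values gathered from a random source head at an earlier time step, and the commit replaces the two independently drawn Bernoulli masks by a single shared one. Below is a minimal PyTorch sketch of that masked-update pattern in isolation; the names flashback_update, rec, and src are illustrative, not from the repository:

    import torch

    def flashback_update(rec, src, proba):
        # Elementwise Bernoulli mask: each entry is 1 with probability proba.
        mask = (torch.rand(rec.shape, device=rec.device) <= proba).long()
        # Where mask == 1 take the flashback value, elsewhere keep the old one.
        return mask * src + (1 - mask) * rec

    # Toy shapes standing in for (N, CH, t1 - t0, DV) in the hunk.
    rec_V = torch.zeros(2, 3, 5, 4)
    src_V = torch.ones(2, 3, 5, 4)
    print(flashback_update(rec_V, src_V, proba=1e-2).mean().item())  # ~0.01

One consequence of sharing the mask: since it is drawn with last dimension DV, reusing it for the rec_K update makes the V and K replacements coincide positionally, and the expression only broadcasts if DK equals DV (or the shapes are otherwise broadcast-compatible).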