X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=mygpt.py;h=f10f1fe2013c4755a8ea7f2eef86a6b932059f25;hb=6183291906184569c2206c34588d118cc77f74bb;hp=f97af49bbb60edb3eca5d26ab96c6e5cccc9dd07;hpb=0d25f8a86e80850cf6a6e27d419f7b043c6028f1;p=mygptrnn.git diff --git a/mygpt.py b/mygpt.py index f97af49..f10f1fe 100755 --- a/mygpt.py +++ b/mygpt.py @@ -472,15 +472,18 @@ def flash_back_time_src(N, H, t0, t1, CL, CH, proba, device): fb_body = fb_body.cumsum(dim=2) fb_start = fb_start * (fb_body == 1) - # pick past starting source times - src_time = ( - fb_start + # t_s = t0-(t0//L * R)*L + + t = torch.arange(fb_start.size(2), device=fb_start.device)[None, None, :] + src_time = fb_start * ( + t + - CL * ( - torch.rand(fb_start.size(), device=fb_start.device) - * (torch.arange(fb_start.size(2), device=fb_start.device) - CL)[ - None, None, : - ] - ).long() + 1 + + ( + torch.rand(fb_start.size(), device=fb_start.device) * (t // CL - 1) + ).long() + ) ) src_time[:, :, CL:] -= src_time.clone()[:, :, :-CL] src_time = src_time.cumsum(dim=2)