X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=51538366383be455f160ab6158392627ecea5190;hb=19ec7f3e4030ddece2647983dcf1bed5eb0d9544;hp=fddcaff34994d31bdb4837772d42ddc4f40d0727;hpb=9eadd2cd6913a0de53b4b0f526157497e8d14381;p=picoclvr.git diff --git a/tasks.py b/tasks.py index fddcaff..5153836 100755 --- a/tasks.py +++ b/tasks.py @@ -1938,8 +1938,13 @@ class Escape(Task): range(it_len, result.size(1) - it_len + 1, it_len), desc="thinking" ): # Put the lookahead reward to either 0 or -1 for the - # current iteration, sample the next state - s = -(torch.rand(result.size(0), device=result.device) < 0.2).long() + # current iteration, with a proba that depends with the + # sequence index, so that we have diverse examples, sample + # the next state + s = -( + torch.rand(result.size(0), device=result.device) + <= torch.linspace(0, 1, result.size(0), device=result.device) + ).long() result[:, u - 1] = s + 1 + escape.first_lookahead_rewards_code ar_mask = (t >= u).long() * (t < u + state_len).long() ar(result, ar_mask) @@ -1956,6 +1961,7 @@ class Escape(Task): # Extract the rewards r = result[:, range(v + state_len + 1 + it_len, u + it_len - 1, it_len)] r = r - escape.first_rewards_code - 1 + r = r.clamp(min=-1, max=1) # the reward is predicted hence can be weird a = r.min(dim=1).values b = r.max(dim=1).values s = (a < 0).long() * a + (a >= 0).long() * b