X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=29f1e5a783024a3948867d2ffc0dae8c71b09954;hb=2be22c9825d8aebe8d184e9501355a31318abf2b;hp=6b6b8f2a5ed055c3f80473f70fa7b0ac87f6a526;hpb=621231cc5bb94f983c556a1b450b66067bec4165;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 6b6b8f2..29f1e5a 100755 --- a/tasks.py +++ b/tasks.py @@ -1944,13 +1944,14 @@ class Escape(Task): # Generate iteration after iteration optimistic_bias = result.new_zeros(self.nb_codes, device=result.device) - optimistic_bias[(-1) + escape.first_lookahead_rewards_code + 1] = math.log(1e-1) - optimistic_bias[(1) + escape.first_lookahead_rewards_code + 1] = math.log(1e1) + optimistic_bias[escape.lookahead_reward2code(-1)] = -math.log(1e1) + optimistic_bias[escape.lookahead_reward2code(1)] = math.log(1e1) for u in tqdm.tqdm( range(it_len, result.size(1) - it_len + 1, it_len), desc="thinking" ): - # Generate the lookahead_reward pessimistically + # Re-generate the lookahead_reward pessimistically in the + # previous iterations ar_mask = (t < u).long() * (t % it_len == index_lookahead_reward).long() ar(result, ar_mask, logit_biases=-optimistic_bias) @@ -1958,7 +1959,8 @@ class Escape(Task): ar_mask = (t >= u).long() * (t < u + state_len).long() ar(result, ar_mask) - # Generate the lookahead_reward optimistically + # Re-generate the lookahead_reward optimistically in the + # previous iterations ar_mask = (t < u).long() * (t % it_len == index_lookahead_reward).long() ar(result, ar_mask, logit_biases=optimistic_bias)