From 2be22c9825d8aebe8d184e9501355a31318abf2b Mon Sep 17 00:00:00 2001
From: =?utf8?q?Fran=C3=A7ois=20Fleuret?=
Date: Mon, 25 Mar 2024 12:45:32 +0100
Subject: [PATCH] Update.

---
 escape.py | 38 +++++++++++++++++++++++++++-----------
 tasks.py  | 10 ++++++----
 2 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/escape.py b/escape.py
index f51863b..6f4af35 100755
--- a/escape.py
+++ b/escape.py
@@ -25,6 +25,33 @@ nb_codes = first_lookahead_rewards_code + nb_lookahead_rewards_codes
 ######################################################################
 
 
+def action2code(r):
+    return first_actions_code + r
+
+
+def code2action(r):
+    return r - first_actions_code
+
+
+def reward2code(r):
+    return first_rewards_code + r + 1
+
+
+def code2reward(r):
+    return r - first_rewards_code - 1
+
+
+def lookahead_reward2code(r):
+    return first_lookahead_rewards_code + r + 1
+
+
+def code2lookahead_reward(r):
+    return r - first_lookahead_rewards_code - 1
+
+
+######################################################################
+
+
 def generate_episodes(nb, height=6, width=6, T=10, nb_walls=3):
     rnd = torch.rand(nb, height, width)
     rnd[:, 0, :] = 0
@@ -111,17 +138,6 @@ def episodes2seq(states, actions, rewards, lookahead_delta=None):
     actions = actions[:, :, None] + first_actions_code
 
     if lookahead_delta is not None:
-        # r = rewards
-        # u = F.pad(r, (0, lookahead_delta - 1)).as_strided(
-        #     (r.size(0), r.size(1), lookahead_delta),
-        #     (r.size(1) + lookahead_delta - 1, 1, 1),
-        # )
-        # a = u[:, :, 1:].min(dim=-1).values
-        # b = u[:, :, 1:].max(dim=-1).values
-        # s = (a < 0).long() * a + (a >= 0).long() * b
-        # lookahead_rewards = (1 + s[:, :, None]) + first_lookahead_rewards_code
-
-        # a[n,t]=min_s>t r[n,s]
         a = rewards.new_zeros(rewards.size())
         b = rewards.new_zeros(rewards.size())
         for t in range(a.size(1) - 1):
diff --git a/tasks.py b/tasks.py
index 6b6b8f2..29f1e5a 100755
--- a/tasks.py
+++ b/tasks.py
@@ -1944,13 +1944,14 @@ class Escape(Task):
         # Generate iteration after iteration
 
         optimistic_bias = result.new_zeros(self.nb_codes, device=result.device)
-        optimistic_bias[(-1) + escape.first_lookahead_rewards_code + 1] = math.log(1e-1)
-        optimistic_bias[(1) + escape.first_lookahead_rewards_code + 1] = math.log(1e1)
+        optimistic_bias[escape.lookahead_reward2code(-1)] = -math.log(1e1)
+        optimistic_bias[escape.lookahead_reward2code(1)] = math.log(1e1)
 
         for u in tqdm.tqdm(
             range(it_len, result.size(1) - it_len + 1, it_len), desc="thinking"
         ):
-            # Generate the lookahead_reward pessimistically
+            # Re-generate the lookahead_reward pessimistically in the
+            # previous iterations
             ar_mask = (t < u).long() * (t % it_len == index_lookahead_reward).long()
             ar(result, ar_mask, logit_biases=-optimistic_bias)
 
@@ -1958,7 +1959,8 @@ class Escape(Task):
             ar_mask = (t >= u).long() * (t < u + state_len).long()
             ar(result, ar_mask)
 
-            # Generate the lookahead_reward optimistically
+            # Re-generate the lookahead_reward optimistically in the
+            # previous iterations
             ar_mask = (t < u).long() * (t % it_len == index_lookahead_reward).long()
             ar(result, ar_mask, logit_biases=optimistic_bias)
-- 
2.39.5
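
Note appended below the patch, not part of it: a minimal, self-contained sketch of the value <-> token-code helpers this commit adds to escape.py, and of the optimistic logit bias that tasks.py now builds with lookahead_reward2code. The two constants at the top are placeholder values chosen purely for illustration (the real code-space layout is defined at the top of escape.py and does not appear in this diff), and the assumption that lookahead rewards take values in {-1, 0, 1} is inferred from the +1 shift in the helpers and from the calls in tasks.py.

import math

import torch

# Placeholder layout of the token-code space (illustrative values only; the
# real constants are defined at the top of escape.py).
first_lookahead_rewards_code = 20
nb_codes = 23


# Same formulas as the helpers added by this commit; lookahead rewards are
# assumed to take values in {-1, 0, 1}, hence the +1 shift to a non-negative
# code offset.
def lookahead_reward2code(r):
    return first_lookahead_rewards_code + r + 1


def code2lookahead_reward(r):
    return r - first_lookahead_rewards_code - 1


# Round-trip check: each lookahead reward maps to a distinct code and back.
assert [code2lookahead_reward(lookahead_reward2code(r)) for r in (-1, 0, 1)] == [-1, 0, 1]

# Analogue of the bias built in tasks.py: added to the logits it scales the
# unnormalized probability of the lookahead-reward-+1 token up by a factor of
# 10 and of the -1 token down by the same factor; subtracting it instead (the
# pessimistic pass) does the opposite.
optimistic_bias = torch.zeros(nb_codes)
optimistic_bias[lookahead_reward2code(-1)] = -math.log(1e1)
optimistic_bias[lookahead_reward2code(1)] = math.log(1e1)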