X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=324376df60319e9549ae431c5d43dd04f1a29ed9;hb=232299b8af7e66a02e64bb2e47b525e2f50b099d;hp=6a7e639275775421468831e62d3ce3ca90b30aad;hpb=819181b36c1af5c3c606b6dcb11a242e9c43331c;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 6a7e639..324376d 100755 --- a/tasks.py +++ b/tasks.py @@ -1905,7 +1905,10 @@ class Greed(Task): t % self.world.it_len == self.world.index_lookahead_reward ).long() - return lr_mask * self.world.lookahead_reward2code(2) + (1 - lr_mask) * batch + return ( + lr_mask * self.world.lookahead_reward2code(greed.REWARD_UNKNOWN) + + (1 - lr_mask) * batch + ) def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} @@ -1950,7 +1953,7 @@ class Greed(Task): result[:, self.world.it_len :] = -1 # Set the lookahead_reward of the firs to UNKNOWN result[:, self.world.index_lookahead_reward] = self.world.lookahead_reward2code( - 2 + greed.REWARD_UNKNOWN ) t = torch.arange(result.size(1), device=result.device)[None, :] @@ -1965,7 +1968,7 @@ class Greed(Task): if u > 0: result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(2) + ] = self.world.lookahead_reward2code(greed.REWARD_UNKNOWN) ar_mask = (t >= u + self.world.index_states).long() * ( t < u + self.world.index_states + self.world.state_len ).long() @@ -1974,7 +1977,7 @@ class Greed(Task): # Generate the action and reward with lookahead_reward to +1 result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(1) + ] = self.world.lookahead_reward2code(greed.REWARD_PLUS) ar_mask = (t >= u + self.world.index_reward).long() * ( t <= u + self.world.index_action ).long() @@ -1983,7 +1986,7 @@ class Greed(Task): # Set the lookahead_reward to UNKNOWN for the next iterations result[ :, u + self.world.index_lookahead_reward - ] = self.world.lookahead_reward2code(2) + ] = self.world.lookahead_reward2code(gree.REWARD_UNKNOWN) filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt") with open(filename, "w") as f: