X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;fp=tasks.py;h=11879fd44866094711a76b5c7bdebebdda28fa41;hb=0b8185b90014369f0d39892e128ad04a7d9ae872;hp=f2b7709f1dc742979bad1e9a44e27da4525904d3;hpb=1eeba5d817d6e440a93895d42f6e580e9ba273fd;p=picoclvr.git

diff --git a/tasks.py b/tasks.py
index f2b7709..11879fd 100755
--- a/tasks.py
+++ b/tasks.py
@@ -5,7 +5,7 @@

 # Written by Francois Fleuret <francois@fleuret.org>

-import math, os, tqdm
+import math, os, tqdm, warnings

 import torch, torchvision

@@ -1928,6 +1928,8 @@ class Escape(Task):

         result[:, it_len:] = -1

+        snapshots = []
+
         def ar(result, ar_mask, logit_biases=None):
             ar_mask = ar_mask.expand_as(result)
             result *= 1 - ar_mask
@@ -1941,6 +1943,8 @@ class Escape(Task):
                 device=self.device,
                 progress_bar_desc=None,
             )
+            warnings.warn("keeping thinking snapshots", RuntimeWarning)
+            snapshots.append(result[:10].detach().clone())

         # Generate iteration after iteration

@@ -1948,30 +1952,32 @@ class Escape(Task):
         optimistic_bias[escape.lookahead_reward2code(-1)] = -math.log(1e1)
         optimistic_bias[escape.lookahead_reward2code(1)] = math.log(1e1)

-        snapshots = []
-
         for u in tqdm.tqdm(
             range(it_len, result.size(1) - it_len + 1, it_len), desc="thinking"
         ):
+            lr, _, _, _ = escape.seq2episodes(result[:, :u], self.height, self.width)
+
             # Generate the lookahead_reward and state
-            ar_mask = (t >= u + index_lookahead_reward).long() * (
+            ar_mask = (t % it_len == index_lookahead_reward).long() * (
+                t <= u + index_lookahead_reward
+            ).long()
+            ar(result, ar_mask)
+
+            # Generate the state
+            ar_mask = (t >= u + index_states).long() * (
                 t < u + index_states + state_len
             ).long()
             ar(result, ar_mask)

-            snapshots.append(result[:10].detach().clone())
-
-            backup_lookahead_reward = result[:, u + index_lookahead_reward]
-
             # Re-generate the lookahead_reward
-            ar_mask = (t == u + index_lookahead_reward).long()
+            ar_mask = (t % it_len == index_lookahead_reward).long() * (
+                t <= u + index_lookahead_reward
+            ).long()
             ar(result, ar_mask, logit_biases=optimistic_bias)

-            snapshots.append(result[:10].detach().clone())
-
             # Generate the action and reward
             ar_mask = (t >= u + index_action).long() * (t <= u + index_reward).long()
             ar(result, ar_mask)

-            snapshots.append(result[:10].detach().clone())
-
-            result[:, u + index_lookahead_reward] = backup_lookahead_reward

         filename = os.path.join(result_dir, f"test_thinking_compute_{n_epoch:04d}.txt")
         with open(filename, "w") as f:
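
Note on the masking and biasing idiom introduced by this commit. The sketch below is illustrative only, not code from the repository: the sequence-layout constants and the toy vocabulary are made up. It shows (i) how a mask of the form (t % it_len == index_lookahead_reward).long() * (t <= u + index_lookahead_reward).long() selects the lookahead-reward slot of every iteration up to the current one, rather than a single absolute position as the old (t == u + index_lookahead_reward) mask did, and (ii) how an additive logit bias of +/- math.log(1e1) multiplies or divides a token's odds by 10 before sampling.

    # Illustrative sketch only -- made-up sizes, not the commit's code.
    import math
    import torch

    # Hypothetical layout: iterations of length 5, lookahead reward at offset 1.
    it_len, index_lookahead_reward, seq_len = 5, 1, 20
    t = torch.arange(seq_len)[None, :]  # shape (1, seq_len), as in tasks.py

    u = 2 * it_len  # pretend we are at the third "thinking" iteration
    ar_mask = (t % it_len == index_lookahead_reward).long() * (
        t <= u + index_lookahead_reward
    ).long()
    # One selected slot per iteration so far: positions 1, 6, 11.
    print(ar_mask.nonzero()[:, 1].tolist())

    # An additive bias of log(10) on a logit scales that token's odds by 10.
    logits = torch.zeros(4)  # uniform over a toy 4-token vocabulary
    optimistic_bias = torch.zeros(4)
    optimistic_bias[3] = math.log(1e1)   # say code 3 encodes lookahead_reward == 1
    optimistic_bias[2] = -math.log(1e1)  # and code 2 encodes lookahead_reward == -1
    # Token 3 becomes ~10x more likely than the unbiased tokens 0 and 1.
    print((logits + optimistic_bias).softmax(dim=0))

Under these assumptions, moving snapshots.append into ar() (together with the warnings.warn call) records the sequence after every masked generation pass, which is why the three per-step snapshot appends inside the loop could be deleted.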