From: François Fleuret Date: Wed, 27 Mar 2024 06:22:09 +0000 (+0100) Subject: Update. X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=commitdiff_plain;h=209f48148a9695069de628ba679345e8f95efa30;p=culture.git Update. --- diff --git a/greed.py b/greed.py index 3cbe886..20cef79 100755 --- a/greed.py +++ b/greed.py @@ -300,7 +300,8 @@ def save_seq_as_anim_script(seq, filename): ) f.write(episodes2str(lr, s, a, r, unicode=True, ansi_colors=True)) f.write("EOF\n") - f.write("sleep 0.5\n") + f.write("sleep 0.25\n") + print(f"Saved {filename}") if __name__ == "__main__": diff --git a/tasks.py b/tasks.py index 845b5b3..aa5df72 100755 --- a/tasks.py +++ b/tasks.py @@ -1905,6 +1905,15 @@ class Greed(Task): self.index_reward = self.state_len + 2 self.it_len = self.state_len + 3 # lookahead_reward / state / action / reward + def wipe_lookahead_rewards(self, batch): + t = torch.arange(batch.size(1), device=batch.device)[None, :] + u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device) + lr_mask = (t <= u).long() * ( + t % self.it_len == self.index_lookahead_reward + ).long() + + return lr_mask * greed.lookahead_reward2code(2) + (1 - lr_mask) * batch + def batches(self, split="train", nb_to_use=-1, desc=None): assert split in {"train", "test"} input = self.train_input if split == "train" else self.test_input @@ -1915,14 +1924,7 @@ class Greed(Task): for batch in tqdm.tqdm( input.split(self.batch_size), dynamic_ncols=True, desc=desc ): - t = torch.arange(batch.size(1), device=batch.device)[None, :] - u = torch.randint(batch.size(1), (batch.size(0), 1), device=batch.device) - lr_mask = (t <= u).long() * ( - t % self.it_len == self.index_lookahead_reward - ).long() - - batch = lr_mask * greed.lookahead_reward2code(2) + (1 - lr_mask) * batch - yield batch + yield self.wipe_lookahead_rewards(batch) def vocabulary_size(self): return greed.nb_codes @@ -2010,7 +2012,7 @@ class Greed(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 ): - result = self.test_input[:250].clone() + result = self.wipe_lookahead_rewards(self.test_input[:250].clone()) # Saving the ground truth