X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;fp=tasks.py;h=fddcaff34994d31bdb4837772d42ddc4f40d0727;hb=9eadd2cd6913a0de53b4b0f526157497e8d14381;hp=56c2b0fd3f8f5b54eb80116ed96aa20581044279;hpb=62ad2378c60cdf322c0111279bd45fbef8365fc2;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 56c2b0f..fddcaff 100755 --- a/tasks.py +++ b/tasks.py @@ -1939,7 +1939,7 @@ class Escape(Task): ): # Put the lookahead reward to either 0 or -1 for the # current iteration, sample the next state - s = -1 # (torch.rand(result.size(0), device = result.device) < 0.2).long() + s = -(torch.rand(result.size(0), device=result.device) < 0.2).long() result[:, u - 1] = s + 1 + escape.first_lookahead_rewards_code ar_mask = (t >= u).long() * (t < u + state_len).long() ar(result, ar_mask) @@ -1980,7 +1980,7 @@ class Escape(Task): def produce_results( self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000 ): - result = self.test_input[:100].clone() + result = self.test_input[:250].clone() # Saving the ground truth