X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=38c85ed84d228b217e214b68adbbda1f76c255a9;hb=690470307a7995cc3117eb54545f921eedcecba5;hp=a4ef557d2adde2f7df591adba4d3efc6b31e7d5f;hpb=12671bdc1a083514de11399041f74747a6ca601b;p=picoclvr.git diff --git a/tasks.py b/tasks.py index a4ef557..38c85ed 100755 --- a/tasks.py +++ b/tasks.py @@ -1885,10 +1885,10 @@ class Escape(Task): self.width = width states, actions, rewards = escape.generate_episodes( - nb_train_samples + nb_test_samples, height, width, 3 * T + nb_train_samples + nb_test_samples, height, width, T ) seq = escape.episodes2seq(states, actions, rewards, lookahead_delta=T) - seq = seq[:, seq.size(1) // 3 : 2 * seq.size(1) // 3] + # seq = seq[:, seq.size(1) // 3 : 2 * seq.size(1) // 3] self.train_input = seq[:nb_train_samples].to(self.device) self.test_input = seq[nb_train_samples:].to(self.device)