projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
b003cc9
)
Update.
author
François Fleuret
<francois@fleuret.org>
Wed, 21 Jun 2023 08:40:05 +0000
(10:40 +0200)
committer
François Fleuret
<francois@fleuret.org>
Wed, 21 Jun 2023 08:40:05 +0000
(10:40 +0200)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
acecfdd
..
e723866
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-8,7
+8,7
@@
# torch.backends.cuda.matmul.allow_tf23
# torch.autocast(torch.bfloat16)
# torch.backends.cuda.matmul.allow_tf23
# torch.autocast(torch.bfloat16)
-import math, sys, argparse, time, tqdm,
itertools,
os
+import math, sys, argparse, time, tqdm, os
import torch, torchvision
from torch import nn
import torch, torchvision
from torch import nn
@@
-27,7
+27,8
@@
else:
######################################################################
parser = argparse.ArgumentParser(
######################################################################
parser = argparse.ArgumentParser(
- description="An implementation of GPT with cache to solve a toy geometric reasoning task."
+ description="An implementation of GPT with cache.",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--task", type=str, default="picoclvr")
)
parser.add_argument("--task", type=str, default="picoclvr")
@@
-40,7
+41,7
@@
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--nb_epochs", type=int, default=25)
parser.add_argument("--nb_epochs", type=int, default=25)
-parser.add_argument("--batch_size", type=int, default=
25
)
+parser.add_argument("--batch_size", type=int, default=
None
)
parser.add_argument("--nb_train_samples", type=int, default=250000)
parser.add_argument("--nb_train_samples", type=int, default=250000)
@@
-128,6
+129,28
@@
if args.seed >= 0:
######################################################################
######################################################################
+default_args = {
+ "picoclvr": {
+ "batch_size": 25,
+ },
+ "mnist": {
+ "batch_size": 10,
+ },
+ "maze": {
+ "batch_size": 25,
+ },
+ "snake": {
+ "batch_size": 20,
+ },
+}
+
+if args.task in default_args:
+ for k, v in default_args[args.task].items():
+ if getattr(args, k) is None:
+ setattr(args, k, v)
+
+######################################################################
+
def log_string(s):
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
def log_string(s):
t = time.strftime("%Y%m%d-%H:%M:%S ", time.localtime())
@@
-639,6
+662,8
@@
def generate_snake_sequences(
nb, height, width, nb_colors, length, device=torch.device("cpu")
):
worlds = torch.randint(nb_colors, (nb, height, width), device=device)
nb, height, width, nb_colors, length, device=torch.device("cpu")
):
worlds = torch.randint(nb_colors, (nb, height, width), device=device)
+ nb_prior_visits = torch.zeros(nb, height, width, device=device)
+
# nb x 2
snake_position = torch.cat(
(
# nb x 2
snake_position = torch.cat(
(
@@
-649,6
+674,9
@@
def generate_snake_sequences(
)
snake_direction = torch.randint(4, (nb,), device=device)
sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
)
snake_direction = torch.randint(4, (nb,), device=device)
sequences = torch.empty(nb, 2 * length, device=device, dtype=torch.int64)
+ sequences_prior_visits = torch.zeros(
+ nb, 2 * length, device=device, dtype=torch.int64
+ )
i = torch.arange(nb, device=device) # [:,None]
for l in range(length):
i = torch.arange(nb, device=device) # [:,None]
for l in range(length):
@@
-680,7
+708,10
@@
def generate_snake_sequences(
),
).float()
val = (
),
).float()
val = (
- torch.rand_like(val) * val * torch.tensor([[1.0, 4.0, 1.0]], device=device)
+ # The multiplicative factors bias toward moving forward
+ torch.rand_like(val)
+ * val
+ * torch.tensor([[1.0, 2.0, 1.0]], device=device)
)
# nb
)
# nb
@@
-688,12
+719,16
@@
def generate_snake_sequences(
snake_direction = snake_next_direction[i, j]
sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
snake_direction = snake_next_direction[i, j]
sequences[:, 2 * l] = worlds[i, snake_position[:, 0], snake_position[:, 1]] + 4
+ sequences_prior_visits[:, 2 * l] = nb_prior_visits[
+ i, snake_position[:, 0], snake_position[:, 1]
+ ]
+ nb_prior_visits[i, snake_position[:, 0], snake_position[:, 1]] += 1
sequences[:, 2 * l + 1] = snake_direction
# nb x 2
snake_position = snake_next_position[i, j]
sequences[:, 2 * l + 1] = snake_direction
# nb x 2
snake_position = snake_next_position[i, j]
- return sequences,
world
s
+ return sequences,
sequences_prior_visit
s
# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
# generate_snake_sequences(nb=1, height=4, width=6, nb_colors=3, length=20)
@@
-717,10
+752,10
@@
class TaskSnake(Task):
self.width = width
self.device = device
self.width = width
self.device = device
- self.train_input, self.train_
world
s = generate_snake_sequences(
+ self.train_input, self.train_
prior_visit
s = generate_snake_sequences(
nb_train_samples, height, width, nb_colors, length, self.device
)
nb_train_samples, height, width, nb_colors, length, self.device
)
- self.test_input, self.test_
world
s = generate_snake_sequences(
+ self.test_input, self.test_
prior_visit
s = generate_snake_sequences(
nb_test_samples, height, width, nb_colors, length, self.device
)
nb_test_samples, height, width, nb_colors, length, self.device
)
@@
-746,32
+781,39
@@
class TaskSnake(Task):
t = model.training
model.eval()
t = model.training
model.eval()
- def compute_nb_correct(input):
+ def compute_nb_correct(input
, prior_visits
):
result = input.clone()
result = input.clone()
- i = torch.arange(result.size(1), device=result.device)
- ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0)[
- None, :
- ].long()
+ i = torch.arange(result.size(1), device=result.device)[None, :]
+ ar_mask = torch.logical_and(i >= i.size(0) // 2, i % 2 == 0).long()
result *= 1 - ar_mask
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
result *= 1 - ar_mask
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
- nb_total = ar_mask.sum() * input.size(0)
- nb_correct = ((result == input).long() * ar_mask).sum()
+ nb_total = (
+ (prior_visits > 0) * ar_mask
+ ).sum()
+
+ nb_correct = (
+ (result == input).long() * (prior_visits > 0) * ar_mask
+ ).sum()
# nb_total = result.size(0)
# nb_correct = ((result - input).abs().sum(1) == 0).sum()
return nb_total, nb_correct
# nb_total = result.size(0)
# nb_correct = ((result - input).abs().sum(1) == 0).sum()
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_nb_correct(self.train_input)
+ train_nb_total, train_nb_correct = compute_nb_correct(
+ self.train_input, self.train_prior_visits
+ )
log_string(
f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
log_string(
f"accuracy_train nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_nb_correct(self.test_input)
+ test_nb_total, test_nb_correct = compute_nb_correct(
+ self.test_input, self.test_prior_visits
+ )
log_string(
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
log_string(
f"accuracy_test nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"