projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
main.py
diff --git
a/main.py
b/main.py
index
fabebdd
..
c22ae57
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-16,13
+16,16
@@
import mygpt, tasks, problems
######################################################################
######################################################################
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
-######################################################################
+def str2bool(x):
+ x = x.lower()
+ if x in {"1", "true", "yes"}:
+ return True
+ elif x in {"0", "false", "no"}:
+ return False
+ else:
+ raise ValueError
+
parser = argparse.ArgumentParser(
description="An implementation of GPT with cache.",
parser = argparse.ArgumentParser(
description="An implementation of GPT with cache.",
@@
-44,6
+47,8
@@
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
@@
-68,7
+73,7
@@
parser.add_argument("--min_learning_rate", type=float, default=6e-5)
# legacy
# legacy
-parser.add_argument("--legacy_lr_schedule",
action="store_true", default=Fals
e)
+parser.add_argument("--legacy_lr_schedule",
type=str2bool, default=Tru
e)
parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
parser.add_argument("--legacy_large_lr", type=float, default=1e-4)
@@
-96,8
+101,6
@@
parser.add_argument("--caterpillar_height", type=int, default=None)
parser.add_argument("--rho", type=float, default=0.0)
parser.add_argument("--rho", type=float, default=0.0)
-parser.add_argument("--dim_rec_v", type=int, default=None)
-
parser.add_argument("--nb_blocks", type=int, default=None)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--nb_blocks", type=int, default=None)
parser.add_argument("--dropout", type=float, default=0.1)
@@
-108,7
+111,7
@@
parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--no_checkpoint", action="store_true", default=False)
-parser.add_argument("--
overwrite_results
", action="store_true", default=False)
+parser.add_argument("--
continue_training
", action="store_true", default=False)
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
@@
-199,15
+202,25
@@
parser.add_argument("--mixing_deterministic_start", action="store_true", default
######################################################################
######################################################################
-args = parser.parse_args()
+
#
args = parser.parse_args()
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+args, sup_args = parser.parse_known_args()
+
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
if args.result_dir is None:
args.result_dir = f"results_{args.task}_{args.model}"
######################################################################
if args.result_dir is None:
args.result_dir = f"results_{args.task}_{args.model}"
######################################################################
+if not args.force_cpu and torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
+
+######################################################################
+
default_task_args = {
"addition": {
"model": "352M",
default_task_args = {
"addition": {
"model": "352M",
@@
-321,7
+334,6
@@
default_model_args = {
"dim_keys": 32,
"dim_hidden": 32,
"nb_heads": 2,
"dim_keys": 32,
"dim_hidden": 32,
"nb_heads": 2,
- "dim_rec_v": 16,
"nb_blocks": 2,
},
"17K-C": {
"nb_blocks": 2,
},
"17K-C": {
@@
-332,7
+344,6
@@
default_model_args = {
"nb_heads": 2,
"nb_lines": 16,
"caterpillar_height": 4,
"nb_heads": 2,
"nb_lines": 16,
"caterpillar_height": 4,
- "dim_rec_v": 16,
"nb_blocks": 2,
},
"4M": {
"nb_blocks": 2,
},
"4M": {
@@
-341,7
+352,6
@@
default_model_args = {
"dim_keys": 32,
"dim_hidden": 1024,
"nb_heads": 4,
"dim_keys": 32,
"dim_hidden": 1024,
"nb_heads": 4,
- "dim_rec_v": 64,
"nb_blocks": 6,
},
"4M-C": {
"nb_blocks": 6,
},
"4M-C": {
@@
-352,7
+362,6
@@
default_model_args = {
"nb_heads": 4,
"nb_lines": 32,
"caterpillar_height": 4,
"nb_heads": 4,
"nb_lines": 32,
"caterpillar_height": 4,
- "dim_rec_v": 64, # dim_model / nb_heads
"nb_blocks": 6,
},
"37M": {
"nb_blocks": 6,
},
"37M": {
@@
-361,7
+370,6
@@
default_model_args = {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 64,
"nb_blocks": 12,
},
"37M-C": {
"nb_blocks": 12,
},
"37M-C": {
@@
-372,7
+380,6
@@
default_model_args = {
"nb_heads": 8,
"nb_lines": 256,
"caterpillar_height": 32,
"nb_heads": 8,
"nb_lines": 256,
"caterpillar_height": 32,
- "dim_rec_v": 64,
"nb_blocks": 12,
},
"122M": {
"nb_blocks": 12,
},
"122M": {
@@
-381,7
+388,6
@@
default_model_args = {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 96,
"nb_blocks": 24,
},
"122M-C": {
"nb_blocks": 24,
},
"122M-C": {
@@
-391,7
+397,6
@@
default_model_args = {
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
- "dim_rec_v": 96,
"nb_blocks": 24,
},
"352M": {
"nb_blocks": 24,
},
"352M": {
@@
-400,7
+405,6
@@
default_model_args = {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 128,
"nb_blocks": 48,
},
"352M-C": {
"nb_blocks": 48,
},
"352M-C": {
@@
-410,7
+414,6
@@
default_model_args = {
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
- "dim_rec_v": 128,
"nb_blocks": 48,
},
}
"nb_blocks": 48,
},
}
@@
-427,10
+430,12
@@
else:
try:
os.mkdir(args.result_dir)
except FileExistsError:
try:
os.mkdir(args.result_dir)
except FileExistsError:
- if not args.
overwrite_results
:
+ if not args.
continue_training
:
print(f"result directory {args.result_dir} already exists")
exit(1)
print(f"result directory {args.result_dir} already exists")
exit(1)
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
@@
-460,13
+465,16
@@
with os.popen("sha256sum *.py") as f:
log_string(f"sha256sum {l.strip()}")
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
log_string(f"sha256sum {l.strip()}")
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+os.system(f"tar
--ignore-failed-read
zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
log_string(f"argv {' '.join(sys.argv)}")
for n in vars(args):
log_string(f"args.{n} {getattr(args, n)}")
log_string(f"argv {' '.join(sys.argv)}")
for n in vars(args):
log_string(f"args.{n} {getattr(args, n)}")
+for n in vars(sup_args):
+ log_string(f"sup_args.{n} {getattr(sup_args, n)}")
+
######################################################################
######################################################################
@@
-478,7
+486,7
@@
def get_lr(n_epoch, it):
if it < args.nb_warmup_iter:
return args.legacy_large_lr * it / args.nb_warmup_iter
if it < args.nb_warmup_iter:
return args.legacy_large_lr * it / args.nb_warmup_iter
- elif
it
< args.legacy_nb_epoch_large_lr:
+ elif
n_epoch
< args.legacy_nb_epoch_large_lr:
return args.legacy_large_lr
else:
return args.legacy_small_lr
return args.legacy_large_lr
else:
return args.legacy_small_lr
@@
-504,6
+512,9
@@
def get_lr(n_epoch, it):
######################################################################
######################################################################
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
@@
-725,11
+736,12
@@
model = mygpt.MyGPT(
nb_heads=args.nb_heads,
nb_lines=args.nb_lines,
caterpillar_height=args.caterpillar_height,
nb_heads=args.nb_heads,
nb_lines=args.nb_lines,
caterpillar_height=args.caterpillar_height,
- dim_rec_v=args.dim_rec_v,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
attention_layer=args.attention,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
attention_layer=args.attention,
+ logger=log_string,
+ **sup_args,
)
model.to(device)
)
model.to(device)
@@
-834,10
+846,12
@@
if nb_epochs_finished >= nb_epochs:
deterministic_synthesis=args.deterministic_synthesis,
)
deterministic_synthesis=args.deterministic_synthesis,
)
-time_pred_result =
None
+time_pred_result =
datetime.datetime.now()
it = 0
it = 0
+n_batch = 0
+
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@
-879,6
+893,12
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
total_loss.backward()
optimizer.step()
total_loss.backward()
optimizer.step()
+ grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+
+ loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+
+ n_batch += 1
+
with torch.autograd.no_grad():
model.eval()
with torch.autograd.no_grad():
model.eval()
@@
-912,10
+932,9
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
)
time_current_result = datetime.datetime.now()
)
time_current_result = datetime.datetime.now()
- if time_pred_result is not None:
- log_string(
- f"next_result {time_current_result + (time_current_result - time_pred_result)}"
- )
+ log_string(
+ f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+ )
time_pred_result = time_current_result
checkpoint = {
time_pred_result = time_current_result
checkpoint = {