projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
main.py
diff --git
a/main.py
b/main.py
index
cae20f8
..
969b47f
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-16,14
+16,6
@@
import mygpt, tasks, problems
######################################################################
######################################################################
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
-
-######################################################################
-
def str2bool(x):
x = x.lower()
def str2bool(x):
x = x.lower()
@@
-55,6
+47,8
@@
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
@@
-107,8
+101,6
@@
parser.add_argument("--caterpillar_height", type=int, default=None)
parser.add_argument("--rho", type=float, default=0.0)
parser.add_argument("--rho", type=float, default=0.0)
-parser.add_argument("--dim_rec_v", type=int, default=None)
-
parser.add_argument("--nb_blocks", type=int, default=None)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--nb_blocks", type=int, default=None)
parser.add_argument("--dropout", type=float, default=0.1)
@@
-119,7
+111,7
@@
parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--no_checkpoint", action="store_true", default=False)
-parser.add_argument("--
overwrite_results
", action="store_true", default=False)
+parser.add_argument("--
continue_training
", action="store_true", default=False)
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
@@
-219,6
+211,14
@@
if args.result_dir is None:
######################################################################
######################################################################
+if not args.force_cpu and torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
+
+######################################################################
+
default_task_args = {
"addition": {
"model": "352M",
default_task_args = {
"addition": {
"model": "352M",
@@
-332,7
+332,6
@@
default_model_args = {
"dim_keys": 32,
"dim_hidden": 32,
"nb_heads": 2,
"dim_keys": 32,
"dim_hidden": 32,
"nb_heads": 2,
- "dim_rec_v": 16,
"nb_blocks": 2,
},
"17K-C": {
"nb_blocks": 2,
},
"17K-C": {
@@
-343,7
+342,6
@@
default_model_args = {
"nb_heads": 2,
"nb_lines": 16,
"caterpillar_height": 4,
"nb_heads": 2,
"nb_lines": 16,
"caterpillar_height": 4,
- "dim_rec_v": 16,
"nb_blocks": 2,
},
"4M": {
"nb_blocks": 2,
},
"4M": {
@@
-352,7
+350,6
@@
default_model_args = {
"dim_keys": 32,
"dim_hidden": 1024,
"nb_heads": 4,
"dim_keys": 32,
"dim_hidden": 1024,
"nb_heads": 4,
- "dim_rec_v": 64,
"nb_blocks": 6,
},
"4M-C": {
"nb_blocks": 6,
},
"4M-C": {
@@
-363,7
+360,6
@@
default_model_args = {
"nb_heads": 4,
"nb_lines": 32,
"caterpillar_height": 4,
"nb_heads": 4,
"nb_lines": 32,
"caterpillar_height": 4,
- "dim_rec_v": 64, # dim_model / nb_heads
"nb_blocks": 6,
},
"37M": {
"nb_blocks": 6,
},
"37M": {
@@
-372,7
+368,6
@@
default_model_args = {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 64,
"nb_blocks": 12,
},
"37M-C": {
"nb_blocks": 12,
},
"37M-C": {
@@
-383,7
+378,6
@@
default_model_args = {
"nb_heads": 8,
"nb_lines": 256,
"caterpillar_height": 32,
"nb_heads": 8,
"nb_lines": 256,
"caterpillar_height": 32,
- "dim_rec_v": 64,
"nb_blocks": 12,
},
"122M": {
"nb_blocks": 12,
},
"122M": {
@@
-392,7
+386,6
@@
default_model_args = {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 96,
"nb_blocks": 24,
},
"122M-C": {
"nb_blocks": 24,
},
"122M-C": {
@@
-402,7
+395,6
@@
default_model_args = {
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
- "dim_rec_v": 96,
"nb_blocks": 24,
},
"352M": {
"nb_blocks": 24,
},
"352M": {
@@
-411,7
+403,6
@@
default_model_args = {
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
"dim_keys": 64,
"dim_hidden": 2048,
"nb_heads": 8,
- "dim_rec_v": 128,
"nb_blocks": 48,
},
"352M-C": {
"nb_blocks": 48,
},
"352M-C": {
@@
-421,7
+412,6
@@
default_model_args = {
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
"dim_hidden": 2048,
"nb_heads": 8,
"nb_lines": 128,
- "dim_rec_v": 128,
"nb_blocks": 48,
},
}
"nb_blocks": 48,
},
}
@@
-438,7
+428,7
@@
else:
try:
os.mkdir(args.result_dir)
except FileExistsError:
try:
os.mkdir(args.result_dir)
except FileExistsError:
- if not args.
overwrite_results
:
+ if not args.
continue_training
:
print(f"result directory {args.result_dir} already exists")
exit(1)
print(f"result directory {args.result_dir} already exists")
exit(1)
@@
-736,7
+726,6
@@
model = mygpt.MyGPT(
nb_heads=args.nb_heads,
nb_lines=args.nb_lines,
caterpillar_height=args.caterpillar_height,
nb_heads=args.nb_heads,
nb_lines=args.nb_lines,
caterpillar_height=args.caterpillar_height,
- dim_rec_v=args.dim_rec_v,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
nb_blocks=args.nb_blocks,
causal=True,
dropout=args.dropout,
@@
-845,7
+834,7
@@
if nb_epochs_finished >= nb_epochs:
deterministic_synthesis=args.deterministic_synthesis,
)
deterministic_synthesis=args.deterministic_synthesis,
)
-time_pred_result =
None
+time_pred_result =
datetime.datetime.now()
it = 0
it = 0
@@
-923,10
+912,9
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
)
time_current_result = datetime.datetime.now()
)
time_current_result = datetime.datetime.now()
- if time_pred_result is not None:
- log_string(
- f"next_result {time_current_result + (time_current_result - time_pred_result)}"
- )
+ log_string(
+ f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+ )
time_pred_result = time_current_result
checkpoint = {
time_pred_result = time_current_result
checkpoint = {