projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygptrnn.git]
/
main.py
diff --git
a/main.py
b/main.py
index
74e1d6c
..
3e67a73
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-16,14
+16,6
@@
import mygpt, tasks, problems
######################################################################
######################################################################
-if torch.cuda.is_available():
- device = torch.device("cuda")
- torch.backends.cuda.matmul.allow_tf32 = True
-else:
- device = torch.device("cpu")
-
-######################################################################
-
def str2bool(x):
x = x.lower()
def str2bool(x):
x = x.lower()
@@
-55,6
+47,8
@@
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
parser.add_argument("--max_percents_of_test_in_train", type=int, default=1)
+parser.add_argument("--force_cpu", type=str2bool, default=False)
+
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
########################################
parser.add_argument("--nb_epochs", type=int, default=50)
@@
-117,7
+111,7
@@
parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
parser.add_argument("--no_checkpoint", action="store_true", default=False)
parser.add_argument("--no_checkpoint", action="store_true", default=False)
-parser.add_argument("--
overwrite_results
", action="store_true", default=False)
+parser.add_argument("--
continue_training
", action="store_true", default=False)
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
@@
-208,15
+202,25
@@
parser.add_argument("--mixing_deterministic_start", action="store_true", default
######################################################################
######################################################################
-args = parser.parse_args()
+
#
args = parser.parse_args()
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+args, sup_args = parser.parse_known_args()
+
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
if args.result_dir is None:
args.result_dir = f"results_{args.task}_{args.model}"
######################################################################
if args.result_dir is None:
args.result_dir = f"results_{args.task}_{args.model}"
######################################################################
+if not args.force_cpu and torch.cuda.is_available():
+ device = torch.device("cuda")
+ torch.backends.cuda.matmul.allow_tf32 = True
+else:
+ device = torch.device("cpu")
+
+######################################################################
+
default_task_args = {
"addition": {
"model": "352M",
default_task_args = {
"addition": {
"model": "352M",
@@
-426,10
+430,12
@@
else:
try:
os.mkdir(args.result_dir)
except FileExistsError:
try:
os.mkdir(args.result_dir)
except FileExistsError:
- if not args.
overwrite_results
:
+ if not args.
continue_training
:
print(f"result directory {args.result_dir} already exists")
exit(1)
print(f"result directory {args.result_dir} already exists")
exit(1)
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
@@
-466,6
+472,9
@@
log_string(f"argv {' '.join(sys.argv)}")
for n in vars(args):
log_string(f"args.{n} {getattr(args, n)}")
for n in vars(args):
log_string(f"args.{n} {getattr(args, n)}")
+for k, v in sup_args.items():
+ log_string(f'sup_args["{k}"] "{v}"')
+
######################################################################
######################################################################
@@
-503,6
+512,9
@@
def get_lr(n_epoch, it):
######################################################################
######################################################################
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
@@
-728,6
+740,8
@@
model = mygpt.MyGPT(
causal=True,
dropout=args.dropout,
attention_layer=args.attention,
causal=True,
dropout=args.dropout,
attention_layer=args.attention,
+ logger=log_string,
+ **sup_args,
)
model.to(device)
)
model.to(device)
@@
-832,10
+846,12
@@
if nb_epochs_finished >= nb_epochs:
deterministic_synthesis=args.deterministic_synthesis,
)
deterministic_synthesis=args.deterministic_synthesis,
)
-time_pred_result =
None
+time_pred_result =
datetime.datetime.now()
it = 0
it = 0
+n_batch = 0
+
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
@@
-877,6
+893,12
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
total_loss.backward()
optimizer.step()
total_loss.backward()
optimizer.step()
+ grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters()]).sqrt()
+
+ loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+
+ n_batch += 1
+
with torch.autograd.no_grad():
model.eval()
with torch.autograd.no_grad():
model.eval()
@@
-910,10
+932,9
@@
for n_epoch in range(nb_epochs_finished, nb_epochs):
)
time_current_result = datetime.datetime.now()
)
time_current_result = datetime.datetime.now()
- if time_pred_result is not None:
- log_string(
- f"next_result {time_current_result + (time_current_result - time_pred_result)}"
- )
+ log_string(
+ f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+ )
time_pred_result = time_current_result
checkpoint = {
time_pred_result = time_current_result
checkpoint = {