projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
main.py
diff --git
a/main.py
b/main.py
index
02e1a8d
..
6b46fa0
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-13,7
+13,7
@@
from torch.nn import functional as F
import ffutils
import mygpt
import ffutils
import mygpt
-import sky,
reasoning
, quiz_machine
+import sky,
grids
, quiz_machine
# world quizzes vs. culture quizzes
# world quizzes vs. culture quizzes
@@
-79,7
+79,7
@@
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
-parser.add_argument("--problem", type=str, default="
sky
")
+parser.add_argument("--problem", type=str, default="
grids
")
parser.add_argument("--nb_gpts", type=int, default=5)
parser.add_argument("--nb_gpts", type=int, default=5)
@@
-251,8
+251,8
@@
if args.problem == "sky":
speed=args.sky_speed,
)
back_accuracy = False
speed=args.sky_speed,
)
back_accuracy = False
-elif args.problem == "
reasoning
":
- problem =
reasoning.Reasoning
(device=device)
+elif args.problem == "
grids
":
+ problem =
grids.Grids
(device=device)
back_accuracy = True
else:
raise ValueError
back_accuracy = True
else:
raise ValueError
@@
-418,6
+418,7
@@
def create_c_quizzes(
)
file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
)
file_name = os.path.join(args.result_dir, f"culture_c_quiz_{n_epoch:04d}_logp.dat")
+
with open(file_name, "w") as logp_file:
while (
valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)
with open(file_name, "w") as logp_file:
while (
valid_c_quizzes(quizzes_and_nb_correct_records, standard_validity).size(0)