projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
main.py
diff --git
a/main.py
b/main.py
index
d7fb3d1
..
b88847e
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-13,7
+13,7
@@
from torch.nn import functional as F
import ffutils
import mygpt
import ffutils
import mygpt
-import sky, quizz_machine
+import sky, wireworld, quizz_machine
# world quizzes vs. culture quizzes
# world quizzes vs. culture quizzes
@@
-37,7
+37,7
@@
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
-parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
+parser.add_argument("--log_filename", type=str, default="train.log")
parser.add_argument("--result_dir", type=str, default=None)
parser.add_argument("--result_dir", type=str, default=None)
@@
-79,6
+79,8
@@
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
parser.add_argument("--deterministic_synthesis", action="store_true", default=False)
+parser.add_argument("--problem", type=str, default="sky")
+
parser.add_argument("--nb_gpts", type=int, default=5)
parser.add_argument("--nb_models_for_generation", type=int, default=1)
parser.add_argument("--nb_gpts", type=int, default=5)
parser.add_argument("--nb_models_for_generation", type=int, default=1)
@@
-219,8
+221,15
@@
else:
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-quizz_machine = quizz_machine.QuizzMachine(
+if args.problem=="sky":
problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),
problem=sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2, speed=2),
+elif args.problem=="wireworld":
+ problem=wireworld.Wireworld(height=10, width=15, nb_iterations=4)
+else:
+ raise ValueError
+
+quizz_machine = quizz_machine.QuizzMachine(
+ problem=problem,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
@@
-343,7
+352,6
@@
def run_tests(model, quizz_machine, deterministic_synthesis):
n_epoch=n_epoch,
model=model,
result_dir=args.result_dir,
n_epoch=n_epoch,
model=model,
result_dir=args.result_dir,
- logger=log_string,
deterministic_synthesis=deterministic_synthesis,
)
deterministic_synthesis=deterministic_synthesis,
)
@@
-397,7
+405,6
@@
def create_c_quizzes(
min_ave_seq_logproba=min_ave_seq_logproba,
n_epoch=n_epoch,
result_dir=args.result_dir,
min_ave_seq_logproba=min_ave_seq_logproba,
n_epoch=n_epoch,
result_dir=args.result_dir,
- logger=log_string,
)
sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
)
sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
@@
-487,7
+494,8
@@
for n_epoch in range(args.nb_epochs):
a = [(model.id, float(model.main_test_accuracy)) for model in models]
a.sort(key=lambda p: p[0])
a = [(model.id, float(model.main_test_accuracy)) for model in models]
a.sort(key=lambda p: p[0])
- log_string(f"current accuracies {a}")
+ s = " ".join([f"{p[1]*100:.02f}%" for p in a])
+ log_string(f"current accuracies {s}")
# select the model with lowest accuracy
models.sort(key=lambda model: model.main_test_accuracy)
# select the model with lowest accuracy
models.sort(key=lambda model: model.main_test_accuracy)