projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
main.py
diff --git
a/main.py
b/main.py
index
ee4e9e5
..
05c3557
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-12,7
+12,15
@@
from torch import nn
from torch.nn import functional as F
import ffutils
from torch.nn import functional as F
import ffutils
-import mygpt, tasks
+import mygpt, quizz_machine
+
+# world quizzes vs. culture quizzes
+
+######################################################################
+
+accuracy_to_make_c_quizzes = 0.975
+nb_new_c_quizzes_for_train = 1000
+nb_new_c_quizzes_for_test = 100
######################################################################
######################################################################
@@
-73,7
+81,7
@@
parser.add_argument("--deterministic_synthesis", action="store_true", default=Fa
parser.add_argument("--nb_gpts", type=int, default=5)
parser.add_argument("--nb_gpts", type=int, default=5)
-parser.add_argument("--
check
", action="store_true", default=False)
+parser.add_argument("--
dirty_debug
", action="store_true", default=False)
######################################################################
######################################################################
@@
-84,6
+92,13
@@
if args.result_dir is None:
######################################################################
######################################################################
+if args.dirty_debug:
+ accuracy_to_make_c_quizzes = 0.0
+ nb_new_c_quizzes_for_train = 100
+ nb_new_c_quizzes_for_test = 10
+
+######################################################################
+
default_args = {
"model": "37M",
"batch_size": 100,
default_args = {
"model": "37M",
"batch_size": 100,
@@
-182,9
+197,9
@@
for n in vars(args):
######################################################################
######################################################################
-if args.
check
:
- args.nb_train_samples = 2500
0
- args.nb_test_samples = 100
0
+if args.
dirty_debug
:
+ args.nb_train_samples = 2500
+ args.nb_test_samples = 100
if args.physical_batch_size is None:
args.physical_batch_size = args.batch_size
if args.physical_batch_size is None:
args.physical_batch_size = args.batch_size
@@
-194,7
+209,7
@@
else:
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
assert args.nb_train_samples % args.batch_size == 0
assert args.nb_test_samples % args.batch_size == 0
-
task = tasks.World
(
+
quizz_machine = quizz_machine.QuizzMachine
(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.physical_batch_size,
@@
-207,7
+222,7
@@
task = tasks.World(
log_string(f"device {device}")
log_string(f"device {device}")
-vocabulary_size =
task
.vocabulary_size()
+vocabulary_size =
quizz_machine
.vocabulary_size()
log_string(f"vocabulary_size {vocabulary_size}")
log_string(f"vocabulary_size {vocabulary_size}")
@@
-216,8
+231,10
@@
log_string(f"vocabulary_size {vocabulary_size}")
# Compute the entropy of the training tokens
token_count = 0
# Compute the entropy of the training tokens
token_count = 0
-for input in task.batches(split="train", desc="train-entropy"):
- token_count += F.one_hot(input, num_classes=task.vocabulary_size()).sum((0, 1))
+for input in quizz_machine.batches(split="train", desc="train-entropy"):
+ token_count += F.one_hot(input, num_classes=quizz_machine.vocabulary_size()).sum(
+ (0, 1)
+ )
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
token_probas = token_count / token_count.sum()
entropy = -torch.xlogy(token_probas, token_probas).sum()
train_set_perplexity = math.exp(entropy)
@@
-239,11
+256,11
@@
if args.max_percents_of_test_in_train >= 0:
nb_test, nb_in_train = 0, 0
for test_subset in subsets_as_tuples(
nb_test, nb_in_train = 0, 0
for test_subset in subsets_as_tuples(
-
task
.batches(split="test", desc="test-check"), 25000
+
quizz_machine
.batches(split="test", desc="test-check"), 25000
):
in_train = set()
for train_subset in subsets_as_tuples(
):
in_train = set()
for train_subset in subsets_as_tuples(
-
task
.batches(split="train", desc="train-check"), 25000
+
quizz_machine
.batches(split="train", desc="train-check"), 25000
):
in_train.update(test_subset.intersection(train_subset))
nb_in_train += len(in_train)
):
in_train.update(test_subset.intersection(train_subset))
nb_in_train += len(in_train)
@@
-260,14
+277,14
@@
if args.max_percents_of_test_in_train >= 0:
##############################
##############################
-def one_epoch(model,
task
):
+def one_epoch(model,
quizz_machine
):
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
model.train()
nb_train_samples, acc_train_loss = 0, 0.0
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
model.train()
nb_train_samples, acc_train_loss = 0, 0.0
- for input in
task
.batches(split="train"):
+ for input in
quizz_machine
.batches(split="train"):
input = input.to(device)
if nb_train_samples % args.batch_size == 0:
input = input.to(device)
if nb_train_samples % args.batch_size == 0:
@@
-292,14
+309,14
@@
def one_epoch(model, task):
######################################################################
######################################################################
-def run_tests(model,
task
, deterministic_synthesis):
+def run_tests(model,
quizz_machine
, deterministic_synthesis):
with torch.autograd.no_grad():
model.eval()
nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
with torch.autograd.no_grad():
model.eval()
nb_test_samples, acc_test_loss = 0, 0.0
nb_samples_accumulated = 0
- for input in
task
.batches(split="test"):
+ for input in
quizz_machine
.batches(split="test"):
input = input.to(device)
bs = model(mygpt.BracketedSequence(input))
input = input.to(device)
bs = model(mygpt.BracketedSequence(input))
@@
-311,7
+328,7
@@
def run_tests(model, task, deterministic_synthesis):
nb_test_samples += input.size(0)
nb_test_samples += input.size(0)
- main_test_accuracy =
task
.produce_results(
+ main_test_accuracy =
quizz_machine
.produce_results(
n_epoch=n_epoch,
model=model,
result_dir=args.result_dir,
n_epoch=n_epoch,
model=model,
result_dir=args.result_dir,
@@
-329,52
+346,58
@@
def run_tests(model, task, deterministic_synthesis):
######################################################################
######################################################################
-def create_quizzes(
+def create_
c_
quizzes(
model,
other_models,
model,
other_models,
-
task
,
+
quizz_machine
,
nb_for_train=1000,
nb_for_test=100,
nb_for_train=1000,
nb_for_test=100,
-
desired_average_logits
=None,
+
min_ave_seq_logproba
=None,
):
kept = []
):
kept = []
- sum_logits
=
0
+ sum_logits
, sum_nb_c_quizzes = 0,
0
while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
nb_to_generate = 4 * (nb_for_train + nb_for_test)
while sum([x.size(0) for x in kept]) < nb_for_train + nb_for_test:
nb_to_generate = 4 * (nb_for_train + nb_for_test)
- new_
quizzes, nb_correct, _sum_logits = task.create_new
_quizzes(
+ new_
c_quizzes, nb_correct, ave_seq_logproba = quizz_machine.create_c
_quizzes(
n_epoch=n_epoch,
result_dir=args.result_dir,
logger=log_string,
nb=nb_to_generate,
model=model,
other_models=other_models,
n_epoch=n_epoch,
result_dir=args.result_dir,
logger=log_string,
nb=nb_to_generate,
model=model,
other_models=other_models,
-
desired_average_logits=desired_average_logits
,
+
min_ave_seq_logproba=min_ave_seq_logproba
,
)
)
- sum_logits += _sum_logits
+ sum_logits += new_c_quizzes.size(0) * ave_seq_logproba
+ sum_nb_c_quizzes += new_c_quizzes.size(0)
+
+ to_keep = new_c_quizzes[nb_correct == len(other_models) - 1]
+
+ if args.dirty_debug:
+ to_keep = new_c_quizzes
- to_keep = new_quizzes[nb_correct == len(other_models) - 1]
log_string(
log_string(
- f"keep {to_keep.size(0)}/{new_
quizzes.size(0)} quizzes ({to_keep.size(0)*100/new
_quizzes.size(0):.02f}%)"
+ f"keep {to_keep.size(0)}/{new_
c_quizzes.size(0)} c_quizzes ({to_keep.size(0)*100/new_c
_quizzes.size(0):.02f}%)"
)
)
+
kept.append(to_keep)
kept.append(to_keep)
- new_quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
+ new_
c_
quizzes = torch.cat(kept, dim=0)[: nb_for_train + nb_for_test]
-
task.store_new_quizzes(new
_quizzes[:nb_for_train], for_train=True)
-
task.store_new_quizzes(new
_quizzes[nb_for_train:], for_train=False)
+
quizz_machine.store_c_quizzes(new_c
_quizzes[:nb_for_train], for_train=True)
+
quizz_machine.store_c_quizzes(new_c
_quizzes[nb_for_train:], for_train=False)
-
task.save_image
(
- new_quizzes[:72],
+
quizz_machine.save_quizzes
(
+ new_
c_
quizzes[:72],
args.result_dir,
args.result_dir,
- f"
world_quiz_{n_epoch:04d}_{model.id:02d}.png
",
+ f"
culture_c_quiz_{n_epoch:04d}_{model.id:02d}
",
log_string,
)
log_string,
)
- return sum_logits /
new_quizzes.size(0)
+ return sum_logits /
sum_nb_c_quizzes
######################################################################
######################################################################
@@
-404,16
+427,7
@@
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
######################################################################
######################################################################
-accuracy_to_make_quizzes = 0.975
-nb_new_quizzes_for_train = 1000
-nb_new_quizzes_for_test = 100
-
-if args.check:
- accuracy_to_make_quizzes = 0.0
- nb_new_quizzes_for_train = 100
- nb_new_quizzes_for_test = 10
-
-desired_average_logits = None
+min_ave_seq_logproba = None
for n_epoch in range(args.nb_epochs):
log_string(f"--- epoch {n_epoch} ----------------------------------------")
for n_epoch in range(args.nb_epochs):
log_string(f"--- epoch {n_epoch} ----------------------------------------")
@@
-431,45
+445,45
@@
for n_epoch in range(args.nb_epochs):
)
# improve it
)
# improve it
- one_epoch(model,
task
)
+ one_epoch(model,
quizz_machine
)
-
task.renew_sampl
es(args.nb_train_samples // args.nb_gpts)
+
quizz_machine.renew_w_quizz
es(args.nb_train_samples // args.nb_gpts)
log_string(
log_string(
- f"train_set_composition w
orld {task.nb_batch_samples_world} quizzes {task.nb_batch_samples
_quizzes}"
+ f"train_set_composition w
_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c
_quizzes}"
)
# test it
)
# test it
- run_tests(model,
task
, deterministic_synthesis=False)
+ run_tests(model,
quizz_machine
, deterministic_synthesis=False)
log_string(
log_string(
- f"test_set_composition w
orld {task.nb_batch_samples_world} quizzes {task.nb_batch_samples
_quizzes}"
+ f"test_set_composition w
_quizzes {quizz_machine.nb_batch_w_quizzes} c_quizzes {quizz_machine.nb_batch_c
_quizzes}"
)
)
- if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_quizzes:
+ if min([m.main_test_accuracy for m in models]) >= accuracy_to_make_
c_
quizzes:
other_models = models.copy()
other_models.remove(model)
other_models = models.copy()
other_models.remove(model)
- ave
rage_logits = create
_quizzes(
+ ave
_seq_logproba = create_c
_quizzes(
model,
other_models,
model,
other_models,
-
task
,
- nb_for_train=nb_new_quizzes_for_train,
- nb_for_test=nb_new_quizzes_for_test,
-
desired_average_logits=desired_average_logits
,
+
quizz_machine
,
+ nb_for_train=nb_new_
c_
quizzes_for_train,
+ nb_for_test=nb_new_
c_
quizzes_for_test,
+
min_ave_seq_logproba=min_ave_seq_logproba
,
)
# We keep the first average logits as a reference
)
# We keep the first average logits as a reference
- if
desired_average_logits
is None:
- desired_average_logits = average_logits
+ if
min_ave_seq_logproba
is None:
+ min_ave_seq_logproba = ave_seq_logproba
else:
log_string(
else:
log_string(
- f"
desired_average_logits {desired_average_logits} average_logits {average_logits
}"
+ f"
min_ave_seq_logproba {min_ave_seq_logproba} ave_seq_logproba {ave_seq_logproba
}"
)
# We update everyone
for model in models:
)
# We update everyone
for model in models:
- run_tests(model,
task
, deterministic_synthesis=False)
+ run_tests(model,
quizz_machine
, deterministic_synthesis=False)
######################################################################
######################################################################