log_string(f"vocabulary_size {vocabulary_size}")
-##############################
-
-models = []
-
-for k in range(2):
-    models.append(
-        mygpt.MyGPT(
-            vocabulary_size=vocabulary_size,
-            dim_model=args.dim_model,
-            dim_keys=args.dim_keys,
-            dim_hidden=args.dim_hidden,
-            nb_heads=args.nb_heads,
-            nb_blocks=args.nb_blocks,
-            causal=True,
-            dropout=args.dropout,
-        ).to(device)
-    )
-
-
-nb_parameters = sum(p.numel() for p in models[0].parameters())
-log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
-
######################################################################
# Compute the entropy of the training tokens
log_string(f"learning_rate_schedule {learning_rate_schedule}")
-time_pred_result = None
-
######################################################################
-def one_epoch(model, task, learning_rate):
- log_string(f"learning_rate {learning_rate}")
-
+def one_epoch(model, task):
if args.optim == "sgd":
- optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
+ optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
elif args.optim == "adam":
- optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
elif args.optim == "adamw":
- optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
else:
raise ValueError(f"Unknown optimizer {args.optim}.")
log_string(f"test_perplexity {n_epoch} {test_perplexity}")
- return main_test_accuracy
+ model.main_test_accuracy = main_test_accuracy
######################################################################
    task,
    nb_for_train=1000,
    nb_for_test=100,
-    nb_runs=10,
-    nb_min_correct=9,
-    nb_max_correct=9,
):
    kept = []
            nb=4 * (nb_for_train + nb_for_test),
            model=model,
            other_models=other_models,
-            nb_runs=nb_runs,
        )
-        to_keep = new_quizzes[
-            torch.logical_and(
-                nb_correct >= nb_min_correct, nb_correct <= nb_max_correct
-            )
-        ]
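+        # keep only the quizzes that all but one of the other models solve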
+        to_keep = new_quizzes[nb_correct == len(other_models) - 1]
        log_string(f"keep {to_keep.size(0)} quizzes")
        kept.append(to_keep)
    )
+######################################################################
+
+models = []
+
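+# Instantiate a small population of GPTs. Each model keeps track of its own
+# test accuracy and an id, so the training loop can pick the weakest one at
+# each epoch.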
+for k in range(5):
+    model = mygpt.MyGPT(
+        vocabulary_size=vocabulary_size,
+        dim_model=args.dim_model,
+        dim_keys=args.dim_keys,
+        dim_hidden=args.dim_hidden,
+        nb_heads=args.nb_heads,
+        nb_blocks=args.nb_blocks,
+        causal=True,
+        dropout=args.dropout,
+    ).to(device)
+
+    model.main_test_accuracy = 0.0
+    model.id = k
+
+    models.append(model)
+
+
+nb_parameters = sum(p.numel() for p in models[0].parameters())
+log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
+
######################################################################
accuracy_to_make_quizzes = 0.975
for n_epoch in range(args.nb_epochs):
-    learning_rate = learning_rate_schedule[n_epoch]
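+    # train the model that currently has the lowest test accuracy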
+    models.sort(key=lambda model: model.main_test_accuracy)
-    for m in models:
-        one_epoch(m, task, learning_rate)
-        test_accuracy = run_tests(m, task, deterministic_synthesis=False)
+    model = models[0]
-        if test_accuracy >= accuracy_to_make_quizzes:
-            other_models = models.copy()
-            other_models.remove(m)
-            create_quizzes(m, other_models, task)
+    log_string(
+        f"training model {model.id} main_test_accuracy {model.main_test_accuracy}"
+    )
+
+    one_epoch(model, task)
-    # --------------------------------------------
+    log_string(
+        f"train_set_composition world {task.nb_batch_samples_world} quizzes {task.nb_batch_samples_quizzes}"
+    )
-    time_current_result = datetime.datetime.now()
-    if time_pred_result is not None:
-        log_string(
-            f"next_result {time_current_result + (time_current_result - time_pred_result)}"
+    run_tests(model, task, deterministic_synthesis=False)
+
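+    # once the freshly trained model is accurate enough, use it to generate
+    # new quizzes that are then checked against the other models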
+    if model.main_test_accuracy >= accuracy_to_make_quizzes:
+        other_models = models.copy()
+        other_models.remove(model)
+
+        create_quizzes(
+            model,
+            other_models,
+            task,
+            nb_for_train=1000,
+            nb_for_test=100,
        )
-    time_pred_result = time_current_result
+
######################################################################
        torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=8, padding=2)
        logger(f"wrote {image_name}")
+    def make_ar_mask(self, input):
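+        # mask is 1 on the second half of the token sequence (the part to be
+        # generated) and 0 on the first half (the given prompt)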
+        b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
+        return b.long()[None, :].expand_as(input)
+
    def __init__(
        self,
        nb_train_samples,
        self.train_input = world.generate(
            nb_train_samples, height=self.height, width=self.width
-        )
-        self.train_ar_mask = (
-            (torch.arange(self.train_input.size(1)) > self.train_input.size(1) // 2)
-            .long()[None, :]
-            .expand_as(self.train_input)
-        )
+        ).to(device)
        self.test_input = world.generate(
            nb_test_samples, height=self.height, width=self.width
-        )
-        self.test_ar_mask = (
-            (torch.arange(self.test_input.size(1)) > self.test_input.size(1) // 2)
-            .long()[None, :]
-            .expand_as(self.test_input)
-        )
-
-        self.train_input, self.train_ar_mask = self.train_input.to(
-            device
-        ), self.train_ar_mask.to(device)
-        self.test_input, self.test_ar_mask = self.test_input.to(
-            device
-        ), self.test_ar_mask.to(device)
+        ).to(device)
        self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+        self.train_quizzes = []
+        self.test_quizzes = []
+
        if result_dir is not None:
            self.save_image(
                self.train_input[:96], result_dir, f"world_train.png", logger
            )
- def batches(self, split="train", nb_to_use=-1, desc=None):
+ def batches(self, split="train", desc=None):
assert split in {"train", "test"}
- input = self.train_input if split == "train" else self.test_input
- if nb_to_use > 0:
- input = input[:nb_to_use]
+ if split == "train":
+ input = self.train_input
+ quizzes = self.train_quizzes
+ else:
+ input = self.test_input
+ quizzes = self.test_quizzes
+
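+        # mix the accumulated quizzes into this split: quizzes are capped at
+        # half of the samples, and the world samples are subsampled so that
+        # the total number of samples per epoch stays constant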
+        if len(quizzes) > 0:
+            quizzes = torch.cat(quizzes, dim=0)
+            if quizzes.size(0) > input.size(0) // 2:
+                i = torch.randperm(quizzes.size(0))[: input.size(0) // 2]
+                quizzes = quizzes[i]
+
+            i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
+            input = input[i]
+
+            self.nb_batch_samples_world = input.size(0)
+            self.nb_batch_samples_quizzes = quizzes.size(0)
+
+            input = torch.cat([input, quizzes], dim=0)
+        else:
+            self.nb_batch_samples_world = input.size(0)
+            self.nb_batch_samples_quizzes = 0
+
        if desc is None:
            desc = f"epoch-{split}"
        for batch in tqdm.tqdm(
    def produce_results(
        self, n_epoch, model, result_dir, logger, deterministic_synthesis, nmax=1000
    ):
-        def compute_accuracy(input, ar_mask, logger=None):
-            input, ar_mask = input[:nmax], ar_mask[:nmax]
+        def compute_accuracy(input, logger=None):
+            input = input[:nmax]
+            ar_mask = self.make_ar_mask(input)
            result = input.clone() * (1 - ar_mask)
            masked_inplace_autoregression(
            return nb_total, nb_correct
-        train_nb_total, train_nb_correct = compute_accuracy(
-            self.train_input, self.train_ar_mask
-        )
+        train_nb_total, train_nb_correct = compute_accuracy(self.train_input)
        logger(
            f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
        )
-        test_nb_total, test_nb_correct = compute_accuracy(
-            self.test_input, self.test_ar_mask, logger
-        )
+        test_nb_total, test_nb_correct = compute_accuracy(self.test_input, logger)
        logger(
            f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
        ##############################
-        input, ar_mask = self.test_input[:96], self.test_ar_mask[:96]
+        input = self.test_input[:96]
+        ar_mask = self.make_ar_mask(input)
        result = input.clone() * (1 - ar_mask)
        masked_inplace_autoregression(
        return main_test_accuracy
    def store_new_quizzes(self, new_quizzes, for_train=True):
-        input = self.train_input if for_train else self.test_input
-
-        nb_current = input.size(0)
-        nb_new = new_quizzes.size(0)
-        if nb_new >= nb_current:
-            input[...] = new_quizzes[:nb_current]
+        if for_train:
+            self.train_quizzes.append(new_quizzes)
        else:
-            nb_kept = nb_current - nb_new
-            input[:nb_kept] = input[-nb_kept:].clone()
-            input[nb_kept:] = new_quizzes
+            self.test_quizzes.append(new_quizzes)
    def create_new_quizzes(
-        self, n_epoch, result_dir, logger, nb, models, other_models, nb_runs
+        self,
+        n_epoch,
+        result_dir,
+        logger,
+        nb,
+        model,
+        other_models,
    ):
        new_quizzes = torch.empty(
            nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
            device=self.device,
        )
-        input = (
-            new_quizzes[:, None, :]
-            .expand(-1, nb_runs, -1)
-            .clone()
-            .reshape(-1, new_quizzes.size(-1))
-        )
-        result = input.clone()
+        ar_mask = self.make_ar_mask(new_quizzes)
-        ar_mask = (
-            (torch.arange(result.size(1), device=self.device) > result.size(1) // 2)
-            .long()[None, :]
-            .expand_as(result)
-        )
+        nb_correct = 0
-        dispatch = torch.randint(len(other_models), (result.size(0),))
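+        # let every other model try to solve the freshly generated quizzes and
+        # count, for each quiz, how many of them succeed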
+        for m in other_models:
+            result = new_quizzes.clone()
-        for n, m in enumerate(other_models):
            masked_inplace_autoregression(
                m,
                self.batch_size,
-                result[dispatch == n],
-                ar_mask[dispatch == n],
-                deterministic_synthesis=False,
-                progress_bar_desc=None,
+                result,
+                ar_mask,
+                deterministic_synthesis=True,
+                progress_bar_desc="solving quizzes",
                device=self.device,
            )
-        nb_correct = (
-            (input == result).long().min(dim=-1).values.reshape(-1, nb_runs).sum(dim=-1)
-        )
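+            # a quiz counts as solved by m only if every token of the
+            # regenerated sequence matches the original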
+            nb_correct += (new_quizzes == result).long().min(dim=-1).values
        return new_quizzes, nb_correct