import ffutils
import mygpt
-import sky, grids, quiz_machine
-
-from quiz_machine import one_batch_masked_inplace_autoregression
+import sky, grids
import threading, subprocess
######################################################################
-
-# ------------------------------------------------------
-alien_problem = grids.Grids(
- max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
- chunk_size=100,
- nb_threads=args.nb_threads,
- tasks="symmetry",
-)
-
-alien_quiz_machine = quiz_machine.QuizMachine(
- problem=alien_problem,
- batch_size=args.eval_batch_size,
- result_dir=args.result_dir,
- logger=log_string,
- device=main_device,
-)
-# ------------------------------------------------------
-
-######################################################################
-
problem = grids.Grids(
max_nb_cached_chunks=len(gpus) * args.nb_train_samples // 100,
chunk_size=100,
if not args.resume:
problem.save_some_examples(args.result_dir)
-quiz_machine = quiz_machine.QuizMachine(
- problem=problem,
- batch_size=args.eval_batch_size,
- result_dir=args.result_dir,
- logger=log_string,
- device=main_device,
-)
+
+def pure_noise(nb, device):
+ r = problem.pure_noise(nb, device)
+ r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
+ return r
+
+
+def quiz_set(nb_samples, c_quizzes, c_quiz_multiplier=1):
+ if c_quizzes is None:
+ quizzes = problem.generate_w_quizzes(nb_samples)
+ quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+ quizzes.size(0), -1
+ )
+ nb_w_quizzes = quizzes.size(0)
+ nb_c_quizzes = 0
+ else:
+ if c_quiz_multiplier > 1:
+ n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
+ body = c_quizzes.repeat(n, 1)
+ if n < c_quiz_multiplier:
+ tail = c_quizzes[
+ torch.randperm(c_quizzes.size(0))[: nb_samples // 2 - body.size(0)]
+ ]
+ c_quizzes = torch.cat([body, tail], dim=0)
+ else:
+ c_quizzes = body
+
+ if c_quizzes.size(0) > nb_samples // 2:
+ i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
+ c_quizzes = c_quizzes[i]
+
+ w_quizzes = problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
+ w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
+ w_quizzes.size(0), -1
+ )
+ quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
+ nb_w_quizzes = w_quizzes.size(0)
+ nb_c_quizzes = c_quizzes.size(0)
+
+ i = torch.randperm(quizzes.size(0), device=quizzes.device)
+ quizzes = quizzes[i].contiguous()
+
+ log_string(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
+
+ return quizzes
+
######################################################################
log_string(f"main_device {main_device} gpus {[ str(g) for g in gpus]}")
-vocabulary_size = quiz_machine.vocabulary_size()
+vocabulary_size = problem.nb_token_values
log_string(f"vocabulary_size {vocabulary_size}")
"""Replace every component of the input by a random value with
probability args.proba_prompt_noise."""
input, masks, targets = imt_set[:, 0], imt_set[:, 1], imt_set[:, 2]
- noise = quiz_machine.pure_noise(input.size(0), input.device)
+ noise = pure_noise(input.size(0), input.device)
change = (1 - masks) * (
torch.rand(input.size(), device=input.device) < args.proba_prompt_noise
).long()
proba_erased = 1 - (1 - args.diffusion_proba_corruption) ** t
mask_erased = (r <= proba_erased[:, None]).long()
- noise = quiz_machine.pure_noise(nb, input.device)
+ noise = pure_noise(nb, input.device)
targets = input
input = (1 - mask_erased) * input + mask_erased * noise
masks = input.new_full(input.size(), 1)
# mini-batches second so that we keep only the samples that have
# not stabilized
- all_input = quiz_machine.pure_noise(nb, local_device)
+ all_input = pure_noise(nb, local_device)
all_masks = all_input.new_full(all_input.size(), 1)
all_changed = torch.full((all_input.size(0),), True, device=all_input.device)
def one_epoch(model, n_epoch, c_quizzes, local_device=main_device, train=True):
- quizzes = quiz_machine.quiz_set(
+ quizzes = quiz_set(
args.nb_train_samples if train else args.nb_test_samples,
c_quizzes,
args.c_quiz_multiplier,
# Save some original world quizzes and the full prediction (the four grids)
- quizzes = quiz_machine.quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(
- local_device
- )
- quiz_machine.problem.save_quizzes_as_image(
+ quizzes = quiz_set(25, c_quizzes, args.c_quiz_multiplier).to(local_device)
+ problem.save_quizzes_as_image(
args.result_dir, f"test_{n_epoch}_{model.id}.png", quizzes=quizzes
)
result = predict_full(model=model, input=quizzes, local_device=local_device)
- quiz_machine.problem.save_quizzes_as_image(
+ problem.save_quizzes_as_image(
args.result_dir, f"test_{n_epoch}_{model.id}_predict_full.png", quizzes=result
)
# Save some images of the prediction results
- quizzes = quiz_machine.quiz_set(
- args.nb_test_samples, c_quizzes, args.c_quiz_multiplier
- )
+ quizzes = quiz_set(args.nb_test_samples, c_quizzes, args.c_quiz_multiplier)
imt_set = samples_for_prediction_imt(quizzes.to(local_device))
result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
masks = imt_set[:, 1].to("cpu")
]
predicted_parts = correct_parts.abs()
- quiz_machine.problem.save_quizzes_as_image(
+ problem.save_quizzes_as_image(
args.result_dir,
f"culture_prediction_{n_epoch}_{model.id}.png",
quizzes=result[:128],
# Save some images of the ex nihilo generation of the four grids
result = ae_generate(model, 150, local_device=local_device).to("cpu")
- quiz_machine.problem.save_quizzes_as_image(
+ problem.save_quizzes_as_image(
args.result_dir,
f"culture_generation_{n_epoch}_{model.id}.png",
quizzes=result[:128],
comments = [f"nb_correct {c} nb_wrong {w}" for c, w in zip(nb_correct, nb_wrong)]
- quiz_machine.problem.save_quizzes_as_image(
+ problem.save_quizzes_as_image(
args.result_dir,
filename,
quizzes=c_quizzes,
log_string(f"nb_parameters {nb_parameters} ({int(nb_parameters/1e6)}M)")
-######################################################################
-
-if args.quizzes is not None:
- with open(args.quizzes, "r") as file:
- txt = file.read()
-
- quizzes = quiz_machine.problem.text2quiz(txt)
-
- record = []
-
- quizzes = quizzes.to(main_device)
- for model in models:
- log_string(f"processing {model.id} {args.quizzes}")
- for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
- mask_generate = quiz_machine.make_quiz_mask(
- quizzes=quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=quad
- )
- result = ae_generate(model, (1 - mask_generate) * quizzes, mask_generate)
- record.append(result)
-
- result = torch.cat(record, dim=0)
-
- filename = "result.png"
-
- quiz_machine.problem.save_quizzes_as_image(
- args.result_dir,
- filename,
- quizzes=result,
- delta=True,
- nrow=8,
- )
-
- log_string(f"wrote {filename}")
-
- exit(0)
-
-
######################################################################
c_quizzes = None
+++ /dev/null
-#!/usr/bin/env python
-
-# Any copyright is dedicated to the Public Domain.
-# https://creativecommons.org/publicdomain/zero/1.0/
-
-# Written by Francois Fleuret <francois@fleuret.org>
-
-import math, os, tqdm, warnings, sys
-
-import torch, torchvision
-
-from torch import nn
-from torch.nn import functional as F
-
-import mygpt
-from mygpt import BracketedSequence
-
-import threading
-
-######################################################################
-
-# ar_mask is a tensor with 0s and 1s, of same shape as input, with
-# 1s where tokens should be generated. The others are kept
-# unchanged.
-
-
-def one_batch_masked_inplace_autoregression(
- model,
- input,
- ar_mask,
- acc_seq_logprobas,
- deterministic_synthesis=False,
-):
- if input.size(0) == 0:
- return
-
- to_generate = (ar_mask.sum(0) > 0).nonzero()
-
- if to_generate.min() > 0:
- model(
- BracketedSequence(input, 0, to_generate.min())
- ) # Needed to initialize the model's cache
- for s in range(to_generate.min(), to_generate.max() + 1):
- output = model(BracketedSequence(input, s, 1)).x
-
- logits = output[:, s]
-
- if deterministic_synthesis:
- t_next = logits.argmax(-1)
- else:
- dist = torch.distributions.categorical.Categorical(logits=logits)
- t_next = dist.sample()
-
- all_n = torch.arange(t_next.size(0))
-
- acc_seq_logprobas += ar_mask[:, s] * logits.log_softmax(dim=1)[all_n, t_next]
-
- input[:, s] = ar_mask[:, s] * t_next + (1 - ar_mask[:, s]) * input[:, s]
-
-
-######################################################################
-
-
-class QuizMachine:
- def __init__(
- self,
- problem,
- batch_size,
- result_dir,
- logger,
- device=torch.device("cpu"),
- ):
- super().__init__()
-
- self.problem = problem
- self.batch_size = batch_size
- self.device = device
- self.logger = logger
- self.prompt_len = None
- self.answer_len = None
-
- # quad_order, quad_generate, quad_noise, quad_loss
- self.train_structures = [
- (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- (("f_A", "A", "f_B", "B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
- # (("B", "f_B", "A", "f_A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
- # (("f_B", "B", "f_A", "A"), (0, 0, 0, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
- (("f_B", "f_A", "A", "B"), (0, 1, 1, 1), (0, 0, 0, 0), (1, 1, 1, 1)),
- ]
-
- self.test_structures = self.train_structures
-
- def vocabulary_size(self):
- return self.problem.nb_token_values
-
- ######################################################################
-
- def autoregression(
- self,
- model,
- input,
- ar_mask,
- seq_logprobas,
- progress_bar_desc=None,
- ):
- assert input.size() == ar_mask.size()
-
- batches = zip(
- input.split(self.batch_size),
- ar_mask.split(self.batch_size),
- seq_logprobas.split(self.batch_size),
- )
-
- if progress_bar_desc is not None:
- batches = tqdm.tqdm(
- batches,
- dynamic_ncols=True,
- desc=progress_bar_desc,
- total=(input.size(0) + self.batch_size - 1) // self.batch_size,
- )
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, ar_mask, seq_logprobas in batches:
- one_batch_masked_inplace_autoregression(
- model=model,
- input=input,
- ar_mask=ar_mask,
- acc_seq_logprobas=seq_logprobas,
- deterministic_synthesis=False,
- )
-
- model.train(t)
-
- ######################################################################
-
- def data_input(
- self, nb_samples, c_quiz_bags=[], c_quiz_multiplier=1, data_structures=None
- ):
- if data_structures is None:
- data_structures = self.train_structures
-
- if len(c_quiz_bags) > 0:
- c_quizzes = torch.cat(c_quiz_bags, dim=0)
-
- if c_quiz_multiplier > 1:
- n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
- body = c_quizzes.repeat(n, 1)
- if n < c_quiz_multiplier:
- tail = c_quizzes[
- torch.randperm(c_quizzes.size(0))[
- : nb_samples // 2 - body.size(0)
- ]
- ]
- c_quizzes = torch.cat([body, tail], dim=0)
- else:
- c_quizzes = body
-
- if c_quizzes.size(0) > nb_samples // 2:
- i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
- c_quizzes = c_quizzes[i]
-
- w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
- quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
- else:
- quizzes = self.problem.generate_w_quizzes(nb_samples)
-
- # shuffle
-
- i = torch.randperm(quizzes.size(0), device=quizzes.device)
- quizzes = quizzes[i]
-
- # Re-order and inject noise
-
- quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
- quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
- order_ids = torch.randint(len(data_structures), (quizzes.size(0),))
-
- for j, s in enumerate(data_structures):
- quad_order, quad_generate, quad_noise, quad_loss = s
- i = order_ids == j
- quizzes[i] = self.problem.reconfigure(quizzes[i], quad_order=quad_order)
- quiz_mask_generate[i] = self.make_quiz_mask(
- quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
- )
- quiz_mask_loss[i] = self.make_quiz_mask(
- quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
- )
-
- return quizzes, quiz_mask_generate, quiz_mask_loss
-
- ######################################################################
-
- def pure_noise(self, nb, device):
- r = self.problem.pure_noise(nb, device)
- r = r.view(r.size(0), 4, -1)[:, :, 1:].reshape(r.size(0), -1)
- return r
-
- def quiz_set(self, nb_samples, c_quizzes, c_quiz_multiplier=1):
- if c_quizzes is None:
- quizzes = self.problem.generate_w_quizzes(nb_samples)
- quizzes = quizzes.view(quizzes.size(0), 4, -1)[:, :, 1:].reshape(
- quizzes.size(0), -1
- )
- nb_w_quizzes = quizzes.size(0)
- nb_c_quizzes = 0
- else:
- if c_quiz_multiplier > 1:
- n = min(c_quiz_multiplier, (nb_samples // 2) // c_quizzes.size(0))
- body = c_quizzes.repeat(n, 1)
- if n < c_quiz_multiplier:
- tail = c_quizzes[
- torch.randperm(c_quizzes.size(0))[
- : nb_samples // 2 - body.size(0)
- ]
- ]
- c_quizzes = torch.cat([body, tail], dim=0)
- else:
- c_quizzes = body
-
- if c_quizzes.size(0) > nb_samples // 2:
- i = torch.randperm(c_quizzes.size(0))[: nb_samples // 2]
- c_quizzes = c_quizzes[i]
-
- w_quizzes = self.problem.generate_w_quizzes(nb_samples - c_quizzes.size(0))
- w_quizzes = w_quizzes.view(w_quizzes.size(0), 4, -1)[:, :, 1:].reshape(
- w_quizzes.size(0), -1
- )
- quizzes = torch.cat([w_quizzes, c_quizzes], dim=0)
- nb_w_quizzes = w_quizzes.size(0)
- nb_c_quizzes = c_quizzes.size(0)
-
- i = torch.randperm(quizzes.size(0), device=quizzes.device)
- quizzes = quizzes[i].contiguous()
-
- logger(f"quiz_set nb_w_quizzes {nb_w_quizzes} nb_c_quizzes {nb_c_quizzes}")
-
- return quizzes
-
- ######################################################################
-
- def make_quiz_mask(self, quizzes, quad_order, quad_mask):
- assert quad_order in [s for s, _, _, _ in self.train_structures]
- return self.problem.make_quiz_mask(
- quizzes, quad_order=quad_order, quad_mask=quad_mask
- )
-
- ######################################################################
-
- def predict(self, model, quizzes, quad_order, quad_mask):
- quizzes = quizzes.to(self.device)
- ar_mask = self.make_quiz_mask(
- quizzes=quizzes, quad_order=quad_order, quad_mask=quad_mask
- )
- result = quizzes * (1 - ar_mask)
-
- seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
-
- self.autoregression(
- model=model,
- input=result,
- ar_mask=ar_mask,
- seq_logprobas=seq_logprobas,
- progress_bar_desc="autoregression",
- )
-
- correct = (result == quizzes).min(dim=1).values.long()
-
- # result = result.to("cpu")
- # correct = correct.to("cpu")
- # seq_logprobas = seq_logprobas.to("cpu")
-
- return result, correct, seq_logprobas
-
- ######################################################################
-
- def produce_results(self, n_epoch, model, input, result_dir):
- input = input.to(self.device)
- result = input.new(input.size())
- correct = input.new(input.size(0))
- predicted_parts = input.new(input.size(0), 4)
-
- nb = 0
-
- # We consider all the configurations that we train for
- for quad_order, quad_generate, _, _ in self.test_structures:
- i = self.problem.indices_select(quizzes=input, quad_order=quad_order)
- nb += i.long().sum()
- result[i], correct[i], _ = self.predict(
- model=model, quizzes=input[i], quad_order=quad_order, quad=quad_generate
- )
-
- predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[
- None, :
- ]
- solution_is_deterministic = predicted_parts[i].sum(dim=-1) == 1
- correct[i] = (2 * correct[i] - 1) * (solution_is_deterministic).long()
-
- assert nb == input.size(0)
-
- nb_correct = (correct == 1).long().sum()
- nb_total = (correct != 0).long().sum()
- self.logger(
- f"test_accuracy {n_epoch} model {model.id} val {nb_correct} / {nb_total}"
- )
-
- test_accuracy = (nb_correct / nb_total).item()
-
- ##############################
-
- correct_parts = predicted_parts * correct[:, None]
-
- result = result[:128]
- predicted_parts = predicted_parts[:128]
- correct_parts = correct_parts[:128]
-
- self.problem.save_quizzes_as_image(
- result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}.png",
- quizzes=result,
- predicted_parts=predicted_parts,
- correct_parts=correct_parts,
- )
-
- return test_accuracy
-
- ######################################################################
-
- def randomize_configuations_inplace(self, quizzes, quad_orders):
- r = torch.randint(len(quad_orders), (quizzes.size(0),), device=quizzes.device)
- for c in range(len(quad_orders)):
- quizzes[r == c] = self.problem.reconfigure(
- quizzes[r == c], quad_order=quad_orders[c]
- )
-
- ######################################################################
-
- def store_c_quizzes(self, new_c_quizzes, for_train=True):
- with self.LOCK_C_QUIZZES:
- if for_train:
- self.train_c_quizzes.append(new_c_quizzes.to("cpu"))
- else:
- self.test_c_quizzes.append(new_c_quizzes.to("cpu"))
-
- def save_c_quizzes(self, filename):
- torch.save((self.train_c_quizzes, self.test_c_quizzes), filename)
-
- def load_c_quizzes(self, filename):
- self.train_c_quizzes, self.test_c_quizzes = torch.load(filename)
-
- ######################################################################
-
- def models_logprobas(
- self,
- model,
- c_quizzes,
- quad_order,
- quad_loss,
- quad_noise=None,
- temperature=1.0,
- device=None,
- ):
- if device is None:
- device = self.device
-
- c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
-
- seq_logprobas = torch.zeros(
- c_quizzes.size(0),
- device=device,
- )
-
- with torch.autograd.no_grad():
- t = model.training
- model.eval()
-
- for input, l in zip(
- c_quizzes.split(self.batch_size),
- seq_logprobas.split(self.batch_size),
- ):
- input = input.to(device)
- quiz_mask_loss = self.make_quiz_mask(
- input, quad_order=quad_order, quad_mask=quad_loss
- )
- output = model(mygpt.BracketedSequence(input)).x / temperature
- l[...] = (
- -F.cross_entropy(output.transpose(1, 2), input, reduction="none")
- * quiz_mask_loss
- ).sum(dim=1)
-
- model.train(t)
-
- return seq_logprobas.to("cpu")
-
- ######################################################################
-
- def generate_c_quizzes(self, nb, model_for_generation, procedure, recorder=None):
- seq_logprobas = torch.zeros(nb, device=self.device)
-
- c_quizzes = None
-
- for n_step, setup in enumerate(procedure):
- quad_order, quad_generate, model_modifier = setup
- if c_quizzes is None:
- c_quizzes = self.problem.create_empty_quizzes(nb, quad_order)
- c_quizzes = c_quizzes.to(self.device)
- elif quad_order != pred_quad_order:
- c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
- pred_quad_order = quad_order
-
- if model_modifier is not None:
- model_modifier(model_for_generation)
-
- self.autoregression(
- model=model_for_generation,
- input=c_quizzes,
- ar_mask=self.make_quiz_mask(
- quizzes=c_quizzes, quad_order=quad_order, quad_mask=quad_generate
- ),
- seq_logprobas=seq_logprobas,
- progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
- )
-
- model_for_generation.reset_transformations()
-
- if recorder is not None:
- x = c_quizzes.clone()
- t = torch.tensor(quad_generate, device=x.device)[None, :].expand(
- x.size(0), -1
- )
- recorder.append(
- self.problem.reconfigure([x, t], ("A", "f_A", "B", "f_B"))
- )
-
- c_quizzes = self.problem.reconfigure(c_quizzes, ("A", "f_A", "B", "f_B"))
-
- return c_quizzes.to("cpu")
-
- ######################################################################