######################################################################
-
-def bag_len(bag):
-    return sum([x.size(0) for x in bag])
-
-
-def bag_to_tensors(bag):
-    return tuple(torch.cat([x[i] for x in bag], dim=0) for i in range(len(bag[0])))
-
-
-######################################################################
-
# If we need to move an optimizer to a different device
)
-######################################################################
-
-
-class TokenCat(nn.Module):
-    def __init__(self, m, n):
-        super().__init__()
-        self.m = m
-        self.n = n
-
-    def forward(self, x):
-        u = torch.cat([x.new_zeros(x.size(0), self.n), x], dim=1)
-        u = self.m(u)
-        u = u[:, self.n :]
-        return u
-
-
######################################################################
import attae
######################################################################
-def quiz_validation_(
-    models,
-    c_quizzes,
-    local_device,
-    nb_have_to_be_correct,
-    nb_have_to_be_wrong,
-    nb_mistakes_to_be_wrong,
-    nb_hints,
-    nb_runs=1,
-):
-    ######################################################################
-    # If there are too many quizzes, process them per-batch
-
-    if c_quizzes.size(0) > args.inference_batch_size:
-        record = []
-        for q, nh in zip(
-            c_quizzes.split(args.inference_batch_size),
-            nb_hints.split(args.inference_batch_size),
-        ):
-            record.append(
-                quiz_validation(
-                    models=models,
-                    c_quizzes=q,
-                    local_device=local_device,
-                    nb_have_to_be_correct=nb_have_to_be_correct,
-                    nb_have_to_be_wrong=nb_have_to_be_wrong,
-                    nb_mistakes_to_be_wrong=nb_mistakes_to_be_wrong,
-                    nb_hints=nh,
-                    nb_runs=nb_runs,
-                )
-            )
-
-        r = []
-        for k in range(len(record[0])):
-            r.append(torch.cat([x[k] for x in record], dim=0))
-
-        return tuple(r)
-    ######################################################################
-
-    record_wrong = []
-    nb_correct, nb_wrong = 0, 0
-
-    for i, model in enumerate(models):
-        assert i == model.id  # a bit of paranoia
-        model = copy.deepcopy(model).to(local_device).eval()
-        correct, wrong = True, False
-        for quad in [(1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)]:
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=c_quizzes,
-                quad_order=("A", "f_A", "B", "f_B"),
-                quad_mask=quad,
-            )
-
-            sub_correct, sub_wrong = False, True
-            for _ in range(nb_runs):
-                result = diffuser.generate(
-                    model=model,
-                    x_0=c_quizzes,
-                    mask_generate=mask_generate,
-                    nb_hints=nb_hints,
-                )
-
-                nb_mistakes = (result != c_quizzes).long().sum(dim=1)
-                sub_correct = sub_correct | (nb_mistakes == 0)
-                sub_wrong = sub_wrong & (nb_mistakes >= nb_mistakes_to_be_wrong)
-
-            correct = correct & sub_correct
-            wrong = wrong | sub_wrong
-
-        record_wrong.append(wrong[:, None])
-        nb_correct += correct.long()
-        nb_wrong += wrong.long()
-
-    to_keep = (nb_correct >= nb_have_to_be_correct) & (nb_wrong >= nb_have_to_be_wrong)
-
-    wrong = torch.cat(record_wrong, dim=1)
-
-    return to_keep, nb_correct, nb_wrong, wrong
-
-
-######################################################################
-
-
-def generate_c_quizzes_(models, nb, local_device=main_device):
-    # To be thread-safe we must make copies
-
-    def copy_for_inference(model):
-        return copy.deepcopy(model).to(local_device).eval()
-
-    quad_order = ("A", "f_A", "B", "f_B")
-
-    template = quiz_machine.problem.create_empty_quizzes(
-        nb=args.inference_batch_size, quad_order=quad_order
-    ).to(local_device)
-
-    wanted_nb = nb
-    nb_to_save = 256
-    nb_c_quizzes_per_model = torch.zeros(len(models), device=local_device)
-
-    with torch.autograd.no_grad():
-        record_c_quizzes, record_agreements = [], []
-
-        last_log = -1
-        start_time = time.perf_counter()
-
-        while nb_c_quizzes_per_model.min() < wanted_nb:
-            model = copy_for_inference(models[torch.randint(len(models), (1,)).item()])
-            generator_id = model.id
-
-            mask_generate = quiz_machine.make_quiz_mask(
-                quizzes=template, quad_order=quad_order, quad_mask=(1, 1, 1, 1)
-            )
-
-            c_quizzes = diffuser.generate(model, template, mask_generate)
-
-            to_keep = quiz_machine.problem.trivial(c_quizzes) == False
-            c_quizzes = c_quizzes[to_keep]
-
-            nb_hints = torch.full(
-                (c_quizzes.size(0),), args.nb_hints, device=c_quizzes.device
-            )
-
-            if c_quizzes.size(0) > 0:
-                to_keep, nb_correct, nb_wrong, record_wrong = quiz_validation(
-                    models,
-                    c_quizzes,
-                    local_device,
-                    nb_have_to_be_correct=args.nb_have_to_be_correct,
-                    nb_have_to_be_wrong=args.nb_have_to_be_wrong,
-                    nb_mistakes_to_be_wrong=args.nb_mistakes_to_be_wrong,
-                    nb_hints=nb_hints,
-                    nb_runs=args.nb_runs,
-                )
-
-                # to_keep[...]=True
-
-                q = c_quizzes[to_keep]
-
-                if q.size(0) > 0:
-                    record_c_quizzes.append(q)
-                    a = (record_wrong == False)[to_keep]
-                    record_agreements.append(a)
-                    nb_c_quizzes_per_model += a.long().sum(dim=0)
-
-            duration = time.perf_counter() - start_time
-            nb_generated = nb_c_quizzes_per_model.min().item()
-
-            if last_log < 0 or duration > last_log + 5:
-                last_log = duration
-                if nb_generated > 0:
-                    if nb_generated < wanted_nb:
-                        d = (wanted_nb - nb_generated) * duration / nb_generated
-                        e = (
-                            datetime.datetime.now() + datetime.timedelta(seconds=d)
-                        ).strftime("%a %H:%M")
-                    else:
-                        e = "now!"
-                else:
-                    e = "???"
-
-                log_string(
-                    f"nb_generated {bag_len(record_c_quizzes)} model {generator_id} (finishes {e} -- {int((nb_generated * 3600)/duration)}/h)"
-                )
-
-    duration = time.perf_counter() - start_time
-
-    log_string(f"generate_c_quizz_speed {int(3600 * wanted_nb / duration)}/h")
-
-    c_quizzes = torch.cat(record_c_quizzes, dim=0)
-    agreements = torch.cat(record_agreements, dim=0)
-
-    return c_quizzes.to("cpu"), agreements.to("cpu")
-
-
-######################################################################
-
-
def generate_c_quizzes(models, nb, local_device=main_device):
    record = []
    nb_validated = 0
+
+    start_time = time.perf_counter()
+    last_log = -1
+
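+    # Draw a generator model at random on each iteration and loop until
+    # nb c-quizzes have been validated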
    while nb_validated < nb:
        model = models[torch.randint(len(models), (1,)).item()]
        model = copy.deepcopy(model).to(local_device).eval()
        log_string(f"generate_c_quizzes {nb_validated}")
+        #####################
+
+        # Throttled progress report: at most one line every 10 s, with an
+        # ETA extrapolated from the validation rate so far
+
+        duration = time.perf_counter() - start_time
+
+        if last_log < 0 or duration > last_log + 10:
+            last_log = duration
+            if nb_validated > 0:
+                if nb_validated < nb:
+                    d = (nb - nb_validated) * duration / nb_validated
+                    e = (
+                        datetime.datetime.now() + datetime.timedelta(seconds=d)
+                    ).strftime("%a %H:%M")
+                else:
+                    e = "now!"
+            else:
+                e = "???"
+
+            log_string(
+                f"nb_validated {nb_validated} model {model.id} (finishes {e} -- {int((nb_validated * 3600)/duration)}/h)"
+            )
+
+        #####################
+
+    duration = time.perf_counter() - start_time
+
+    log_string(f"generate_c_quizz_speed {int(3600 * nb / duration)}/h")
+
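+    # All the validated quizzes, concatenated along the batch dimension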
    return torch.cat(record)
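+
+
+# A minimal usage sketch (hypothetical call site; the real training loop,
+# its arguments, and what is done with the result live elsewhere):
+#
+#   c_quizzes = generate_c_quizzes(models, nb=1000, local_device=main_device)
+#   log_string(f"generated {c_quizzes.size(0)} c-quizzes")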