deterministic_synthesis,
forbidden_tokens=None,
logit_biases=None,
- progress_bar_desc="autoregression",
+ progress_bar_desc=None,
device=torch.device("cpu"),
):
assert input.size() == ar_mask.size()
######################################################################
-import sky
-
class QuizzMachine:
def make_ar_mask(self, input):
def __init__(
self,
+ problem,
nb_train_samples,
nb_test_samples,
batch_size,
):
super().__init__()
- self.problem = sky.Sky(height=6, width=8, nb_birds=3, nb_iterations=2)
+ self.problem = problem
self.batch_size = batch_size
self.device = device
if result_dir is not None:
self.problem.save_quizzes(
- self.train_w_quizzes[:72], result_dir, f"culture_w_quizzes", logger
+ self.train_w_quizzes[:72], result_dir, "culture_w_quizzes"
)
def batches(self, split="train", desc=None):
)
self.problem.save_quizzes(
- result[:72],
- result_dir,
- f"culture_prediction_{n_epoch:04d}_{model.id:02d}",
- logger,
+ result[:72], result_dir, f"culture_prediction_{n_epoch:04d}_{model.id:02d}"
)
return main_test_accuracy
def create_c_quizzes(
self,
+ nb,
+ model_for_generation,
+ models_for_validation,
+ min_ave_seq_logproba,
n_epoch,
result_dir,
logger,
- nb,
- model,
- other_models,
- min_ave_seq_logproba,
):
###############################################################
# Generate quizzes with model
seq_logproba[...] = 0
masked_inplace_autoregression(
- model=model,
+ model=model_for_generation,
batch_size=self.batch_size,
input=c_quizzes,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=temperature,
deterministic_synthesis=False,
- progress_bar_desc="sampling c_quizzes",
+ # progress_bar_desc="sampling c_quizzes",
device=self.device,
)
ave_seq_logproba = seq_logproba.mean()
- logger(f"{ave_seq_logproba=} {min_ave_seq_logproba=}")
-
if min_ave_seq_logproba is None:
break
# Oh man that's ugly
- if ave_seq_logproba < min_ave_seq_logproba * 1.1:
+ if ave_seq_logproba < min_ave_seq_logproba:
if d_temperature > 0:
d_temperature *= -1 / 3
temperature += d_temperature
- elif ave_seq_logproba > min_ave_seq_logproba:
+ elif ave_seq_logproba > min_ave_seq_logproba * 0.99:
if d_temperature < 0:
d_temperature *= -1 / 3
temperature += d_temperature
else:
break
- logger(f"chaging temperature to {temperature}")
+ logger(f"changing temperature to {temperature}")
###############################################################
# Create the reverse quizzes
nb_correct = []
- for m in other_models:
+ for model in models_for_validation:
result = c_quizzes.clone()
masked_inplace_autoregression(
- model=m,
+ model=model,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving c_quizzes",
+ # progress_bar_desc="solving c_quizzes",
device=self.device,
)
reverse_result = reverse_c_quizzes.clone()
masked_inplace_autoregression(
- model=m,
+ model=model,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
seq_logproba=seq_logproba,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving reversed c_quizzes",
+ # progress_bar_desc="solving reversed c_quizzes",
device=self.device,
)