######################################################################
-def run_tests(model, task):
+def run_tests(model, task, deterministic_synthesis):
with torch.autograd.no_grad():
model.eval()
model=model,
result_dir=args.result_dir,
logger=log_string,
- deterministic_synthesis=args.deterministic_synthesis,
+ deterministic_synthesis=deterministic_synthesis,
)
test_perplexity = math.exp(min(100, acc_test_loss / nb_test_samples))
one_epoch(model, task, learning_rate)
- run_tests(model, task)
+ run_tests(model, task, deterministic_synthesis=True)
+
+ # --------------------------------------------
time_current_result = datetime.datetime.now()
if time_pred_result is not None:
)
time_pred_result = time_current_result
+ # --------------------------------------------
+
checkpoint = {
"nb_epochs_finished": n_epoch + 1,
"model_state": model.state_dict(),
nb,
height,
width,
- max_nb_obj=len(colors) - 2,
+ max_nb_obj=colors.size(0) - 2,
nb_iterations=2,
):
f_start = torch.zeros(nb, height, width, dtype=torch.int64)
for n in range(nb):
nb_fish = torch.randint(max_nb_obj, (1,)).item() + 1
- for c in range(nb_fish):
+ for c in torch.randperm(colors.size(0) - 2)[:nb_fish].sort().values:
i, j = (
torch.randint(height - 2, (1,))[0] + 1,
torch.randint(width - 2, (1,))[0] + 1,
height, width = 6, 8
start_time = time.perf_counter()
- seq = generate(nb=64, height=height, width=width)
+ seq = generate(nb=64, height=height, width=width, max_nb_obj=3)
delay = time.perf_counter() - start_time
print(f"{seq.size(0)/delay:02f} samples/s")