parser.add_argument(
"--task",
type=str,
- default="picoclvr",
- help="picoclvr, mnist, maze, snake, stack, expr, world",
+ default="sandbox",
+ help="sandbox, picoclvr, mnist, maze, snake, stack, expr, world",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
######################################################################
default_args = {
+ "sandbox": {
+ "nb_epochs": 10,
+ "batch_size": 25,
+ "nb_train_samples": 25000,
+ "nb_test_samples": 10000,
+ },
"picoclvr": {
"nb_epochs": 25,
"batch_size": 25,
"world": {
"nb_epochs": 10,
"batch_size": 25,
- "nb_train_samples": 125000,
+ "nb_train_samples": 25000,
"nb_test_samples": 1000,
},
}
######################################################################
-if args.task == "picoclvr":
+if args.task == "sandbox":
+ task = tasks.SandBox(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "picoclvr":
task = tasks.PicoCLVR(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
pass
+######################################################################
+
+
+class Problem:
+    def generate(self, nb):
+        pass
+
+    def perf(self, seq, logger):
+        pass
+
+
+class ProblemByheart(Problem):
+ def __init__(self):
+ pass
+
+
+class SandBox(Task):
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ logger=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+        self.batch_size = batch_size
+
+        # Assumption: a single hard-coded problem for now.
+        problems = [ProblemByheart()]
+
+        def generate_sequences(nb_samples):
+            # Draw a random problem for every sample and count how many
+            # sequences each problem has to produce.
+            problem_indexes = torch.randint(len(problems), (nb_samples,))
+            nb_samples_per_problem = torch.nn.functional.one_hot(problem_indexes).sum(0)
+            print(f"{nb_samples_per_problem}")
+
+ self.train_input = generate_sequences(nb_train_samples)
+ self.test_input = generate_sequences(nb_test_samples)
+
+ self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
+
+ def batches(self, split="train", nb_to_use=-1, desc=None):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ if nb_to_use > 0:
+ input = input[:nb_to_use]
+ if desc is None:
+ desc = f"epoch-{split}"
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=desc
+ ):
+ yield batch
+
+ def vocabulary_size(self):
+ return self.nb_codes
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ # logger(
+ # f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
+ # )
+ pass
+
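+# A minimal sketch (illustrative only, not referenced anywhere) of how the
+# Problem stubs above could be completed: a "by heart" problem keeps a fixed
+# pool of random sequences for the model to memorize. The class name and the
+# sizes nb_sentences, seq_len and nb_symbols are assumptions made up for this
+# example; SandBox.generate_sequences above would then be expected to return a
+# tensor built from Problem.generate, which it does not do yet.
+class ProblemByheartSketch(Problem):
+    def __init__(self, nb_sentences=100, seq_len=12, nb_symbols=16):
+        # Fixed pool of random token sequences to memorize.
+        self.pool = torch.randint(nb_symbols, (nb_sentences, seq_len))
+
+    def generate(self, nb):
+        # Return nb sequences drawn with replacement from the fixed pool.
+        return self.pool[torch.randint(self.pool.size(0), (nb,))]
+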
+
######################################################################
import picoclvr
pruner_train=None,
pruner_eval=None,
):
+ super().__init__()
+
def generate_descr(nb, cache_suffix, pruner):
return picoclvr.generate(
nb,
def __init__(
self, nb_train_samples, nb_test_samples, batch_size, device=torch.device("cpu")
):
+ super().__init__()
+
        self.nb_train_samples = nb_train_samples
        self.nb_test_samples = nb_test_samples
self.batch_size = batch_size
nb_walls,
device=torch.device("cpu"),
):
+ super().__init__()
+
self.batch_size = batch_size
self.height = height
self.width = width
prompt_length,
device=torch.device("cpu"),
):
+ super().__init__()
+
self.batch_size = batch_size
self.height = height
self.width = width
fraction_values_for_train=None,
device=torch.device("cpu"),
):
+ super().__init__()
+
self.batch_size = batch_size
self.nb_steps = nb_steps
self.nb_stacks = nb_stacks
batch_size,
device=torch.device("cpu"),
):
+ super().__init__()
+
self.batch_size = batch_size
self.device = device
device=torch.device("cpu"),
device_storage=torch.device("cpu"),
):
+ super().__init__()
+
self.batch_size = batch_size
self.device = device
return s
+def loss_H(binary_logits, h_threshold=1):
+ p = binary_logits.sigmoid().mean(0)
+ h = (-p.xlogy(p) - (1 - p).xlogy(1 - p)) / math.log(2)
+ h.clamp_(max=h_threshold)
+ return h_threshold - h.mean()
+
+
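+# Note on the regularizer above: p is the per-unit mean activation probability
+# of the binary logits over the batch, and h is its entropy in bits (1 for
+# p = 0.5, 0 for a unit stuck at 0 or 1). Clamping h at h_threshold and
+# returning h_threshold - h.mean() penalizes only units whose entropy falls
+# below the threshold, i.e. it discourages latent bits from becoming constant.
+# The lambda_entropy parameter below simply scales this term when it is added
+# to the auto-encoder training loss.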
def train_encoder(
train_input,
test_input,
depth=2,
dim_hidden=48,
nb_bits_per_token=8,
+ lambda_entropy=0.0,
lr_start=1e-3,
lr_end=1e-4,
nb_epochs=10,
train_loss = F.cross_entropy(output, input)
+            if lambda_entropy > 0:
+                train_loss = train_loss + lambda_entropy * loss_H(z, h_threshold=0.5)
+
acc_train_loss += train_loss.item() * input.size(0)
optimizer.zero_grad()
)
-def random_scene():
+def random_scene(nb_insert_attempts=3):
scene = []
colors = [
((Box.nb_rgb_levels - 1), 0, 0),
),
]
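+    # Each iteration attempts to insert one random box into the scene; an
+    # attempt may be discarded (e.g. if the box overlaps an existing one), so
+    # the scene can end up with fewer than nb_insert_attempts boxes.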
- for k in range(10):
+ for k in range(nb_insert_attempts):
wh = torch.rand(2) * 0.2 + 0.2
xy = torch.rand(2) * (1 - wh)
c = colors[torch.randint(len(colors), (1,))]
xh, yh = tuple(x.item() for x in torch.rand(2))
actions = torch.randint(len(effects), (len(steps),))
- change = False
+ nb_changes = 0
for s, a in zip(steps, actions):
if s:
frames.append(scene2tensor(xh, yh, scene, size=size))
- g, dx, dy = effects[a]
- if g:
+ grasp, dx, dy = effects[a]
+
+ if grasp:
for b in scene:
if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
x, y = b.x, b.y
else:
xh += dx
yh += dy
- change = True
+ nb_changes += 1
else:
x, y = xh, yh
xh += dx
if xh < 0 or xh > 1 or yh < 0 or yh > 1:
xh, yh = x, y
- if change:
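+        # Heuristic: accept the episode only if a grasped box was actually
+        # moved in more than a third of the steps.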
+ if nb_changes > len(steps) // 3:
break
return frames, actions
steps = [True] + [False] * (nb_steps + 1) + [True]
train_input, train_actions = generate_episodes(nb_train_samples, steps)
- train_input, train_actions = train_input.to(device_storage), train_actions.to(device_storage)
+ train_input, train_actions = train_input.to(device_storage), train_actions.to(
+ device_storage
+ )
test_input, test_actions = generate_episodes(nb_test_samples, steps)
- test_input, test_actions = test_input.to(device_storage), test_actions.to(device_storage)
+ test_input, test_actions = test_input.to(device_storage), test_actions.to(
+ device_storage
+ )
encoder, quantizer, decoder = train_encoder(
- train_input, test_input, nb_epochs=nb_epochs, logger=logger, device=device
+ train_input,
+ test_input,
+ lambda_entropy=1.0,
+ nb_epochs=nb_epochs,
+ logger=logger,
+ device=device,
)
encoder.train(False)
quantizer.train(False)
seq = []
p = pow2.to(device)
for x in input.split(batch_size):
- x=x.to(device)
+ x = x.to(device)
z = encoder(x)
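+            # Binarize the quantizer output; the bits of each latent are then
+            # packed into integer token codes using the powers of two in p.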
ze_bool = (quantizer(z) >= 0).long()
output = (