- test_loss = F.cross_entropy(output, input)
-
- acc_test_loss += test_loss.item() * input.size(0)
-
- train_loss = acc_train_loss / train_input.size(0)
- test_loss = acc_test_loss / test_input.size(0)
-
- logger(f"vqae train {k} lr {lr} train_loss {train_loss} test_loss {test_loss}")
- sys.stdout.flush()
-
- return encoder, quantizer, decoder
-
-
-######################################################################
-
-
-def scene2tensor(xh, yh, scene, size):
- width, height = size, size
- pixel_map = torch.ByteTensor(width, height, 4).fill_(255)
- data = pixel_map.numpy()
- surface = cairo.ImageSurface.create_for_data(
- data, cairo.FORMAT_ARGB32, width, height
- )
-
- ctx = cairo.Context(surface)
- ctx.set_fill_rule(cairo.FILL_RULE_EVEN_ODD)
-
- for b in scene:
- ctx.move_to(b.x * size, b.y * size)
- ctx.rel_line_to(b.w * size, 0)
- ctx.rel_line_to(0, b.h * size)
- ctx.rel_line_to(-b.w * size, 0)
- ctx.close_path()
- ctx.set_source_rgba(
- b.r / (Box.nb_rgb_levels - 1),
- b.g / (Box.nb_rgb_levels - 1),
- b.b / (Box.nb_rgb_levels - 1),
- 1.0,
- )
- ctx.fill()
-
- hs = size * 0.1
- ctx.set_source_rgba(0.0, 0.0, 0.0, 1.0)
- ctx.move_to(xh * size - hs / 2, yh * size - hs / 2)
- ctx.rel_line_to(hs, 0)
- ctx.rel_line_to(0, hs)
- ctx.rel_line_to(-hs, 0)
- ctx.close_path()
- ctx.fill()
-
- return (
- pixel_map[None, :, :, :3]
- .flip(-1)
- .permute(0, 3, 1, 2)
- .long()
- .mul(Box.nb_rgb_levels)
- .floor_divide(256)
- )
-
-
-def random_scene(nb_insert_attempts=3):
- scene = []
- colors = [
- ((Box.nb_rgb_levels - 1), 0, 0),
- (0, (Box.nb_rgb_levels - 1), 0),
- (0, 0, (Box.nb_rgb_levels - 1)),
- ((Box.nb_rgb_levels - 1), (Box.nb_rgb_levels - 1), 0),
- (
- (Box.nb_rgb_levels * 2) // 3,
- (Box.nb_rgb_levels * 2) // 3,
- (Box.nb_rgb_levels * 2) // 3,
- ),
- ]
-
- for k in range(nb_insert_attempts):
- wh = torch.rand(2) * 0.2 + 0.2
- xy = torch.rand(2) * (1 - wh)
- c = colors[torch.randint(len(colors), (1,))]
- b = Box(
- xy[0].item(), xy[1].item(), wh[0].item(), wh[1].item(), c[0], c[1], c[2]
- )
- if not b.collision(scene):
- scene.append(b)
-
- return scene
-
-
-def generate_episode(steps, size=64):
- delta = 0.1
- effects = [
- (False, 0, 0),
- (False, delta, 0),
- (False, 0, delta),
- (False, -delta, 0),
- (False, 0, -delta),
- (True, delta, 0),
- (True, 0, delta),
- (True, -delta, 0),
- (True, 0, -delta),
- ]
-
- while True:
- frames = []
-
- scene = random_scene()
- xh, yh = tuple(x.item() for x in torch.rand(2))
-
- actions = torch.randint(len(effects), (len(steps),))
- nb_changes = 0
-
- for s, a in zip(steps, actions):
- if s:
- frames.append(scene2tensor(xh, yh, scene, size=size))
-
- grasp, dx, dy = effects[a]
-
- if grasp:
- for b in scene:
- if b.x <= xh and b.x + b.w >= xh and b.y <= yh and b.y + b.h >= yh:
- x, y = b.x, b.y
- b.x += dx
- b.y += dy
- if (
- b.x < 0
- or b.y < 0
- or b.x + b.w > 1
- or b.y + b.h > 1
- or b.collision(scene)
- ):
- b.x, b.y = x, y
- else:
- xh += dx
- yh += dy
- nb_changes += 1
- else:
- x, y = xh, yh
- xh += dx
- yh += dy
- if xh < 0 or xh > 1 or yh < 0 or yh > 1:
- xh, yh = x, y
-
- if nb_changes > len(steps) // 3:
- break
-
- return frames, actions