projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (parent:
9d8e9d4
)
Update.
author
François Fleuret
<francois@fleuret.org>
Tue, 16 Jan 2024 13:31:03 +0000
(14:31 +0100)
committer
François Fleuret
<francois@fleuret.org>
Tue, 16 Jan 2024 13:31:03 +0000
(14:31 +0100)
grid.py
patch
|
blob
|
history
diff --git
a/grid.py
b/grid.py
index
268f4ee
..
2135710
100755
(executable)
--- a/
grid.py
+++ b/
grid.py
@@
-9,10
+9,6
@@
import math
import torch, torchvision
import torch.nn.functional as F
import torch, torchvision
import torch.nn.functional as F
-name_shapes = ["A", "B", "C", "D", "E", "F"]
-
-name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
-
######################################################################
######################################################################
@@
-23,20
+19,24
@@
class GridFactory:
max_nb_items=4,
max_nb_transformations=3,
nb_questions=4,
max_nb_items=4,
max_nb_transformations=3,
nb_questions=4,
+ nb_shapes=6,
+ nb_colors=6,
):
assert size % 2 == 0
self.size = size
self.max_nb_items = max_nb_items
self.max_nb_transformations = max_nb_transformations
self.nb_questions = nb_questions
):
assert size % 2 == 0
self.size = size
self.max_nb_items = max_nb_items
self.max_nb_transformations = max_nb_transformations
self.nb_questions = nb_questions
+ self.name_shapes = ["A", "B", "C", "D", "E", "F"]
+ self.name_colors = ["red", "yellow", "blue", "green", "white", "purple"]
def generate_scene(self):
nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
col = torch.full((self.size * self.size,), -1)
shp = torch.full((self.size * self.size,), -1)
def generate_scene(self):
nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
col = torch.full((self.size * self.size,), -1)
shp = torch.full((self.size * self.size,), -1)
- a = torch.randperm(len(
name_colors) * len(
name_shapes))[:nb_items]
- col[:nb_items] = a % len(name_colors)
- shp[:nb_items] = a // len(name_colors)
+ a = torch.randperm(len(
self.name_colors) * len(self.
name_shapes))[:nb_items]
+ col[:nb_items] = a % len(
self.
name_colors)
+ shp[:nb_items] = a // len(
self.
name_colors)
i = torch.randperm(self.size * self.size)
col = col[i]
shp = shp[i]
i = torch.randperm(self.size * self.size)
col = col[i]
shp = shp[i]
@@
-76,12
+76,15
@@
class GridFactory:
# for i in range(self.size):
# for j in range(self.size):
# if col[i,j] >= 0:
# for i in range(self.size):
# for j in range(self.size):
# if col[i,j] >= 0:
- # print(f"at ({i},{j}) {
name_colors[col[i,j]]} {
name_shapes[shp[i,j]]}")
+ # print(f"at ({i},{j}) {
self.name_colors[col[i,j]]} {self.
name_shapes[shp[i,j]]}")
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
- print(f"{name_colors[col[i,j]][0]}{name_shapes[shp[i,j]]}", end="")
+ print(
+ f"{self.name_colors[col[i,j]][0]}{self.name_shapes[shp[i,j]]}",
+ end="",
+ )
elif j == 0:
print(" +", end="")
else:
elif j == 0:
print(" +", end="")
else:
@@
-103,7
+106,7
@@
class GridFactory:
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
for i in range(self.size):
for j in range(self.size):
if col[i, j] >= 0:
- n = f"{
name_colors[col[i,j]]} {
name_shapes[shp[i,j]]}"
+ n = f"{
self.name_colors[col[i,j]]} {self.
name_shapes[shp[i,j]]}"
properties += [f"a {n} at {i} {j}"]
return properties
properties += [f"a {n} at {i} {j}"]
return properties
@@
-116,7
+119,9
@@
class GridFactory:
for i1 in range(self.size):
for j1 in range(self.size):
if col[i1, j1] >= 0:
for i1 in range(self.size):
for j1 in range(self.size):
if col[i1, j1] >= 0:
- n1 = f"{name_colors[col[i1,j1]]} {name_shapes[shp[i1,j1]]}"
+ n1 = (
+ f"{self.name_colors[col[i1,j1]]} {self.name_shapes[shp[i1,j1]]}"
+ )
properties += [f"there is a {n1}"]
if i1 < self.size // 2:
properties += [f"a {n1} is in the top half"]
properties += [f"there is a {n1}"]
if i1 < self.size // 2:
properties += [f"a {n1} is in the top half"]
@@
-129,7
+134,7
@@
class GridFactory:
for i2 in range(self.size):
for j2 in range(self.size):
if col[i2, j2] >= 0:
for i2 in range(self.size):
for j2 in range(self.size):
if col[i2, j2] >= 0:
- n2 = f"{
name_colors[col[i2,j2]]} {
name_shapes[shp[i2,j2]]}"
+ n2 = f"{
self.name_colors[col[i2,j2]]} {self.
name_shapes[shp[i2,j2]]}"
if i1 > i2:
properties += [f"a {n1} is below a {n2}"]
if i1 < i2:
if i1 > i2:
properties += [f"a {n1} is below a {n2}"]
if i1 < i2: