projects
/
mygpt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygpt.git]
/
picoclvr.py
diff --git
a/picoclvr.py
b/picoclvr.py
index
80e5fd0
..
2d57505
100755
(executable)
--- a/
picoclvr.py
+++ b/
picoclvr.py
@@
-93,11
+93,11
@@
def all_properties(height, width, nb_squares, square_i, square_j, square_c):
######################################################################
######################################################################
-def generate(nb, height
= 6, width = 8
,
+def generate(nb, height
, width
,
max_nb_squares = 5, max_nb_properties = 10,
max_nb_squares = 5, max_nb_properties = 10,
-
many_colors = False
):
+
nb_colors = 5
):
- nb_colors = len(color_tokens) - 1 if many_colors else max_nb_squares
+ assert nb_colors >= max_nb_squares and nb_colors <= len(color_tokens) - 1
descr = [ ]
descr = [ ]
@@
-129,7
+129,7
@@
def generate(nb, height = 6, width = 8,
######################################################################
######################################################################
-def descr2img(descr, height
= 6, width = 8
):
+def descr2img(descr, height
, width
):
if type(descr) == list:
return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
if type(descr) == list:
return torch.cat([ descr2img(d, height, width) for d in descr ], 0)
@@
-152,7
+152,7
@@
def descr2img(descr, height = 6, width = 8):
######################################################################
######################################################################
-def descr2properties(descr, height
= 6, width = 8
):
+def descr2properties(descr, height
, width
):
if type(descr) == list:
return [ descr2properties(d, height, width) for d in descr ]
if type(descr) == list:
return [ descr2properties(d, height, width) for d in descr ]
@@
-181,7
+181,7
@@
def descr2properties(descr, height = 6, width = 8):
######################################################################
######################################################################
-def nb_missing_properties(descr, height
= 6, width = 8
):
+def nb_missing_properties(descr, height
, width
):
if type(descr) == list:
return [ nb_missing_properties(d, height, width) for d in descr ]
if type(descr) == list:
return [ nb_missing_properties(d, height, width) for d in descr ]