Minor update.
[pysvrt.git] / vision_problem_13.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with svrt.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "vision_problem_13.h"
26 #include "shape.h"
27
28 VisionProblem_13::VisionProblem_13() { }
29
30 void VisionProblem_13::generate(int label, Vignette *vignette) {
31   Shape big_shape, small_shape;
32   int big_xs1, big_ys1, small_xs1, small_ys1;
33   int big_xs2, big_ys2, small_xs2, small_ys2;
34   int translated_small_xs = 0, translated_small_ys = 0;
35   Vignette tmp;
36   const int dist_min = Vignette::width/4;
37   int nb_attempts;
38   const int max_nb_attempts = 100;
39
40   do {
41     nb_attempts = 0;
42     do {
43
44       vignette->clear();
45
46       big_shape.randomize(big_part_size / 2, big_part_hole_size / 2);
47
48       tmp.clear();
49       do {
50         big_xs1 = int(random_uniform_0_1() * Vignette::width);
51         big_ys1 = int(random_uniform_0_1() * Vignette::height);
52         nb_attempts++;
53       } while(nb_attempts < max_nb_attempts &&
54               big_shape.overwrites(vignette, big_xs1, big_ys1));
55
56       if(nb_attempts < max_nb_attempts) {
57         big_shape.draw(0, vignette, big_xs1, big_ys1);
58         big_shape.draw(0, &tmp, big_xs1, big_ys1);
59         for(int k = 0; k < dist_min; k++) tmp.grow();
60       }
61
62       do {
63         small_shape.randomize(small_part_size / 2, small_part_hole_size / 2);
64         small_xs1 = int(random_uniform_0_1() * Vignette::width);
65         small_ys1 = int(random_uniform_0_1() * Vignette::height);
66         nb_attempts++;
67       } while(nb_attempts < max_nb_attempts &&
68               (!small_shape.overwrites(&tmp, small_xs1, small_ys1) ||
69                small_shape.overwrites(vignette, small_xs1, small_ys1)));
70
71       if(nb_attempts < max_nb_attempts) {
72         small_shape.draw(1, vignette, small_xs1, small_ys1);
73       }
74
75       tmp.clear();
76       do {
77         big_xs2 = int(random_uniform_0_1() * Vignette::width);
78         big_ys2 = int(random_uniform_0_1() * Vignette::height);
79         nb_attempts++;
80       } while(nb_attempts < max_nb_attempts &&
81               big_shape.overwrites(vignette, big_xs2, big_ys2));
82       if(nb_attempts < max_nb_attempts) {
83         big_shape.draw(2, vignette, big_xs2, big_ys2);
84         big_shape.draw(0, &tmp, big_xs2, big_ys2);
85         for(int k = 0; k < dist_min; k++) tmp.grow();
86
87         translated_small_xs = small_xs1 + (big_xs2 - big_xs1);
88         translated_small_ys = small_ys1 + (big_ys2 - big_ys1);
89       }
90     } while(nb_attempts < max_nb_attempts &&
91             small_shape.overwrites(vignette,
92                                    translated_small_xs,
93                                    translated_small_ys));
94
95     if(label) {
96       small_xs2 = translated_small_xs;
97       small_ys2 = translated_small_ys;
98     } else {
99       do {
100         small_xs2 = int(random_uniform_0_1() * Vignette::width);
101         small_ys2 = int(random_uniform_0_1() * Vignette::height);
102         nb_attempts++;
103       } while(nb_attempts < max_nb_attempts &&
104               (sq(small_xs2 - translated_small_xs) + sq(small_ys2 - translated_small_ys) < sq(dist_min) ||
105                !small_shape.overwrites(&tmp, small_xs2, small_ys2) ||
106                small_shape.overwrites(vignette, small_xs2, small_ys2)));
107     }
108   } while(nb_attempts >= max_nb_attempts);
109   small_shape.draw(3, vignette, small_xs2, small_ys2);
110 }