Added #include <unistd.h> for nice()
[svrt.git] / vision_problem_17.cc
1 /*
2  *  svrt is the ``Synthetic Visual Reasoning Test'', an image
3  *  generator for evaluating classification performance of machine
4  *  learning systems, humans and primates.
5  *
6  *  Copyright (c) 2009 Idiap Research Institute, http://www.idiap.ch/
7  *  Written by Francois Fleuret <francois.fleuret@idiap.ch>
8  *
9  *  This file is part of svrt.
10  *
11  *  svrt is free software: you can redistribute it and/or modify it
12  *  under the terms of the GNU General Public License version 3 as
13  *  published by the Free Software Foundation.
14  *
15  *  svrt is distributed in the hope that it will be useful, but
16  *  WITHOUT ANY WARRANTY; without even the implied warranty of
17  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  *  General Public License for more details.
19  *
20  *  You should have received a copy of the GNU General Public License
21  *  along with selector.  If not, see <http://www.gnu.org/licenses/>.
22  *
23  */
24
25 #include "vision_problem_17.h"
26 #include "shape.h"
27
28 VisionProblem_17::VisionProblem_17() { }
29
30 void VisionProblem_17::generate(int label, Vignette *vignette) {
31   const int nb_shapes = 4;
32   int xs[nb_shapes], ys[nb_shapes];
33   int shape_number[nb_shapes];
34
35   ASSERT(nb_shapes == 4);
36
37   int too_ambiguous;
38
39   int error;
40
41   do {
42     Shape shape1, shape2;
43     shape1.randomize(part_size/2, hole_size/2);
44     shape2.randomize(part_size/2, hole_size/2);
45
46     //////////////////////////////////////////////////////////////////////
47
48     do {
49       for(int n = 0; n < nb_shapes; n++) {
50         if(n < nb_shapes - 1) {
51           shape_number[n] = 0;
52         } else {
53           shape_number[n] = 1;
54         }
55         xs[n] = int(drand48() * (Vignette::width - part_size)) + part_size/2;
56         ys[n] = int(drand48() * (Vignette::width - part_size)) + part_size/2;
57       }
58
59       scalar_t a = scalar_t(xs[1] - xs[0]), b = scalar_t(ys[1] - ys[0]);
60       scalar_t c = scalar_t(xs[2] - xs[1]), d = scalar_t(ys[2] - ys[1]);
61       scalar_t det = a * d - b * c;
62       scalar_t u = scalar_t(xs[1] * xs[1] - xs[0] * xs[0] + ys[1] * ys[1] - ys[0] * ys[0]);
63       scalar_t v = scalar_t(xs[2] * xs[2] - xs[1] * xs[1] + ys[2] * ys[2] - ys[1] * ys[1]);
64       scalar_t xc = 1/(2 * det) *(  d * u - b * v);
65       scalar_t yc = 1/(2 * det) *(- c * u + a * v);
66
67       if(label == 1) {
68         xs[nb_shapes - 1] = int(xc);
69         ys[nb_shapes - 1] = int(yc);
70         too_ambiguous = 0;
71       } else {
72         too_ambiguous = sqrt(sq(scalar_t(xs[nb_shapes - 1]) - xc) +
73                              sq(scalar_t(ys[nb_shapes - 1]) - yc)) < scalar_t(part_size);
74       }
75     } while(too_ambiguous ||
76             cluttered_shapes(part_size, nb_shapes, xs, ys));
77
78     //////////////////////////////////////////////////////////////////////
79
80     vignette->clear();
81
82     error = 0;
83     for(int n = 0; n < nb_shapes; n++) {
84       if(shape_number[n] == 0) {
85         error |= shape1.overwrites(vignette, xs[n], ys[n]);
86         shape1.draw(vignette, xs[n], ys[n]);
87       } else {
88         error |= shape2.overwrites(vignette, xs[n], ys[n]);
89         shape2.draw(vignette, xs[n], ys[n]);
90       }
91     }
92   } while(error);
93 }