Cleaning up for DL.
[universe.git] / main.cc
1
2 // Written and (C) by Francois Fleuret
3 // Contact <francois.fleuret@idiap.ch> for comments & bug reports
4
5 #include <iostream>
6 #include <fstream>
7 #include <cmath>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <stdint.h>
11 #include <errno.h>
12 #include <string.h>
13
14 using namespace std;
15
16 #include "misc.h"
17 #include "task.h"
18 #include "simple_window.h"
19 #include "universe.h"
20 #include "plotter.h"
21 #include "manipulator.h"
22 #include "intelligence.h"
23
24 #include "canvas_cairo.h"
25
26 void generate_png(Universe *universe, scalar_t scale, FILE *file) {
27   CanvasCairo canvas(scale, universe->width(), universe->height());
28   universe->draw(&canvas);
29   canvas.write_png(file);
30 }
31
32 // To train
33 // ./main --task hit_shape.task 0 --action-mode=random --nb-ticks=5000 --proportion-for-training=0.5 --save-file=dump.mem --no-window
34
35 // To test
36 // ./main --task hit_shape.task 0 --action-mode=intelligent --load-file=dump.mem
37
38 //////////////////////////////////////////////////////////////////////
39
// Make sure the option found at argv[n_opt] is followed by at least `n`
// arguments, i.e. that argv[n_opt + 1] .. argv[n_opt + n] all exist.
// Otherwise print which option is incomplete together with `help` (a
// description of the expected arguments) on stderr and terminate the
// process with exit status 1.
void check_opt(int argc, char **argv, int n_opt, int n, const char *help) {
  const int last_required = n_opt + n;
  if(last_required < argc) return;

  cerr << "Missing argument for " << argv[n_opt] << "." << endl
       << "Expecting " << help << "." << endl;
  exit(1);
}
47
// Print every command-line option recognized by main() on stdout, then
// terminate the process with exit status `e` (0 for an explicit --help,
// non-zero when called because of a bad argument).
void print_help_and_exit(int e) {
  cout << "Arguments:" << endl;
  // --help and --test are accepted by the parser, so list them as well.
  cout << "  --help" << endl;
  cout << "  --test" << endl;
  cout << "  --no-window" << endl;
  cout << "  --nb-ticks=<int: number of ticks>" << endl;
  cout << "  --nb-training-iterations=<int: number of training iterations>" << endl;
  cout << "  --load-file=<filename: dump file>" << endl;
  cout << "  --save-file=<filename: dump file>" << endl;
  cout << "  --proportion-for-training=<float: proportion of samples for training>" << endl;
  cout << "  --action-mode=<idle|random|intelligent>" << endl;
  cout << "  --task <filename: task to load> <int: degree>" << endl;
  exit(e);
}
60
61 //////////////////////////////////////////////////////////////////////
62
63 int main(int argc, char **argv) {
64
65   const int buffer_size = 256;
66   char intelligence_load_file[buffer_size] = "", intelligence_save_file[buffer_size] = "";
67
68   Task *task = 0;
69   int task_degree = 0;
70
71   int nb_ticks = 10000;
72   int nb_training_iterations = 10;
73   scalar_t proportion_for_training = -1;
74
75   Polygon *grabbed_polygon = 0;
76   scalar_t relative_grab_x = 0, relative_grab_y = 0;
77
78   bool quit = false;
79   bool press_shift = false;
80   enum { IDLE, RANDOM, INTELLIGENT } action_mode = IDLE;
81   bool got_event = true;
82   int current_action = 0;
83
84   scalar_t last_hand_x = 0, last_hand_y = 0;
85   Polygon *last_grabbing = 0;
86   bool no_window = false;
87
88   //////////////////////////////////////////////////////////////////////
89   //                    Parsing the shell arguments
90   //////////////////////////////////////////////////////////////////////
91
92   int i = 1;
93   while(i < argc) {
94     if(argc == 1 || strcmp(argv[i], "--help") == 0) print_help_and_exit(0);
95
96     else if(strcmp(argv[i], "--test") == 0) {
97       test_approximer();
98       exit(0);
99     }
100
101     else if(strcmp(argv[i], "--task") == 0) {
102       check_opt(argc, argv, i, 2, "<filename: task to load> <int: degree>");
103       if(task) {
104         cerr << "Can not load two tasks." << endl;
105         exit(1);
106       }
107       task = load_task(argv[i+1]);
108       task_degree = atoi(argv[i+2]);
109       i += 3;
110
111     } else if(strncmp(argv[i], "--", 2) == 0) {
112       char variable_name[buffer_size] = "", variable_value[buffer_size] = "";
113       char *o = argv[i]+2, *s = variable_name, *u = variable_value;
114       while(*o && *o != '=') *s++ = *o++;
115       if(*o) {
116         o++;
117         while(*o) *u++ = *o++;
118       }
119
120       if(strcmp(variable_name, "nb-ticks") == 0) {
121         nb_ticks = atoi(variable_value);
122       } else if(strcmp(variable_name, "nb-training-iterations") == 0) {
123         nb_training_iterations = atoi(variable_value);
124       } else if(strcmp(variable_name, "proportion-for-training") == 0) {
125         proportion_for_training = atof(variable_value);
126       } else if(strcmp(variable_name, "no-window") == 0) {
127         no_window = true;
128       } else if(strcmp(variable_name, "save-file") == 0) {
129         strcpy(intelligence_save_file, variable_value);
130       } else if(strcmp(variable_name, "load-file") == 0) {
131         strcpy(intelligence_load_file, variable_value);
132       } else if(strcmp(variable_name, "action-mode") == 0) {
133         if(strcmp(variable_value, "idle") == 0) {
134           action_mode = IDLE;
135         } else if(strcmp(variable_value, "random") == 0) {
136           action_mode = RANDOM;
137         } else if(strcmp(variable_value, "intelligent") == 0) {
138           action_mode = INTELLIGENT;
139         } else {
140           cerr << "The only known modes are idle, random and intelligent" << endl;
141           exit(1);
142         }
143       } else {
144         cerr << "Unknown option " << argv[i] << endl;
145         print_help_and_exit(1);
146       }
147       i++;
148     } else {
149       cerr << "Unknown option " << argv[i] << endl;
150       print_help_and_exit(1);
151     }
152   }
153
154   cout << "FlatLand, a toy universe for goal-planning experiments." << endl;
155
156   if(!task) {
157     task = load_task("dummy.so");
158     task_degree = 0;
159   }
160
161   cout << "Loaded task " << task->name()
162        << " with degree " << task_degree << "/" << task->nb_degrees()
163        << endl;
164
165   if(task_degree < 0 || task_degree >= task->nb_degrees()) {
166     cout << "Invalid degree: " << task_degree << "." << endl;
167     exit(1);
168   }
169
170   //////////////////////////////////////////////////////////////////////
171   //                      Various initializations
172   //////////////////////////////////////////////////////////////////////
173
174   Universe universe(100, task->width(), task->height());
175   task->init(&universe, task_degree);
176   Manipulator manipulator(task);
177   manipulator.force_move(task->width()/2, task->height()/2);
178
179   SimpleWindow *window_main = 0;
180   int window_main_fd = -1;
181
182 #ifdef CAIRO_SUPPORT
183   // cairo_t *window_main_cairo_cr = 0;
184 #endif
185
186   MapConcatener sensory_map(2);
187   sensory_map.add_map(&manipulator);
188   sensory_map.init();
189
190   MapExpander expanded_map(1000);
191   expanded_map.set_input(&sensory_map);
192   expanded_map.init();
193
194   Intelligence intelligence(&expanded_map, &manipulator, nb_ticks + 1, nb_training_iterations);
195   intelligence.update(0, 0.0);
196
197   if(intelligence_load_file[0]) {
198     cout << "Loading from " << intelligence_load_file << " ... " ;
199     cout.flush();
200     ifstream in(intelligence_load_file);
201     if(in.fail()) {
202       cerr << "error reading " << intelligence_load_file << "." << endl;
203       exit(1);
204     }
205     intelligence.load(in);
206     cout << "done." << endl ;
207   }
208
209   if(no_window) {
210     cout << "Started without windows." << endl;
211   } else {
212     window_main = new SimpleWindow("Universe (main window)", 4, 4, task->width(), task->height());
213     window_main_fd = window_main->file_descriptor();
214     window_main->map();
215 #ifdef CAIRO_SUPPORT
216     // window_main_cairo_cr = window_main->get_cairo_context_resource();
217 #endif
218     cout << "When the main window has the focus, press `q' to quit and click and drag to move" << endl
219          << "objects." << endl;
220   }
221
222   int tick = 0;
223   time_t last_t = 0;
224   scalar_t sum_reward = 0;
225
226   //////////////////////////////////////////////////////////////////////
227   //                         The main loop
228   //////////////////////////////////////////////////////////////////////
229
230   while(!quit && tick != nb_ticks) {
231
232     int r;
233     fd_set fds;
234
235 #ifdef CAIRO_SUPPORT
236     if(tick < 100) {
237       char buffer[1024];
238       sprintf(buffer, "frame-%06d.png", tick);
239       FILE *file = fopen(buffer, "w");
240       generate_png(&universe, 0.25, file);
241       cout << "Universe image saved in " << buffer << endl;
242       fclose(file);
243     }
244 #endif
245
246     if(window_main) {
247       struct timeval tv;
248       FD_ZERO (&fds);
249       FD_SET (window_main_fd, &fds);
250       tv.tv_sec = 0;
251       tv.tv_usec = 5000; // 0.05s
252       r = select(window_main_fd + 1, &fds, 0, 0, &tv);
253     } else r = 0;
254
255     time_t t = time(0);
256
257     if(t > last_t) {
258       last_t = t;
259       cout << tick << " " << sum_reward << "              \r"; cout.flush();
260     }
261
262     if(r == 0) { // No window event, thus it's the clock tick
263
264         int nb_it = 10;
265
266         bool changed = got_event;
267         got_event = false;
268
269         switch(action_mode) {
270         case IDLE:
271           break;
272         case RANDOM:
273           current_action = manipulator.random_action();
274           break;
275         case INTELLIGENT:
276           current_action = intelligence.best_action();
277           //           if(drand48() < 0.5) current_action = intelligence.best_action();
278           //           else                current_action = manipulator.random_action();
279           break;
280         }
281
282         manipulator.do_action(current_action);
283
284         scalar_t dt = 1.0/scalar_t(nb_it);
285         for(int k = 0; k < nb_it; k++) {
286           manipulator.update(dt, &universe);
287           task->update(dt, &universe, &manipulator);
288           universe.apply_gravity(dt, 0.0, 2.0);
289           changed |= universe.update(dt);
290         }
291
292         tick++;
293
294         changed |= manipulator.hand_x() != last_hand_x ||
295           manipulator.hand_y() != last_hand_y ||
296           manipulator.grabbing() != last_grabbing;
297
298         scalar_t reward = task->reward(&universe, &manipulator);
299         sum_reward += abs(reward);
300         intelligence.update(current_action, reward);
301         expanded_map.update_map();
302
303         if(changed) {
304           last_hand_x = manipulator.hand_x();
305           last_hand_y = manipulator.hand_y();
306           last_grabbing = manipulator.grabbing();
307
308           if(window_main) {
309             window_main->color(0.0, 0.0, 0.0);
310             window_main->color(1.0, 1.0, 1.0);
311             window_main->fill();
312             universe.draw(window_main);
313
314             task->draw(window_main);
315             manipulator.draw_on_universe(window_main);
316
317             if(grabbed_polygon) {
318               int x, y, delta = 3;
319               x = int(grabbed_polygon->absolute_x(relative_grab_x, relative_grab_y));
320               y = int(grabbed_polygon->absolute_y(relative_grab_x, relative_grab_y));
321               window_main->color(0.0, 0.0, 0.0);
322               window_main->draw_line(x - delta, y, x + delta, y);
323               window_main->draw_line(x, y - delta, x, y + delta);
324             }
325
326             window_main->show();
327           }
328         }
329
330     } else if(r > 0) { // We got window events, let's process them
331
332       got_event = true;
333
334       if(FD_ISSET(window_main_fd, &fds)) {
335
336         SimpleEvent se;
337
338         do {
339           se = window_main->event();
340
341           switch(se.type) {
342
343           case SimpleEvent::MOUSE_CLICK_PRESS:
344             {
345               switch(se.button) {
346
347               case 1:
348                 if(press_shift) {
349                   manipulator.force_move(se.x, se.y);
350                   manipulator.do_action(Manipulator::ACTION_GRAB);
351                 } else {
352                   grabbed_polygon = universe.pick_polygon(se.x, se.y);
353                   if(grabbed_polygon) {
354                     relative_grab_x = grabbed_polygon->relative_x(se.x, se.y);
355                     relative_grab_y = grabbed_polygon->relative_y(se.x, se.y);
356                   }
357                 }
358                 break;
359               case 4:
360                 {
361                   Polygon *g = universe.pick_polygon(se.x, se.y);
362                   if(g) g->_theta += M_PI/32;
363                 }
364                 break;
365               case 5:
366                 {
367                   Polygon *g = universe.pick_polygon(se.x, se.y);
368                   if(g) g->_theta -= M_PI/32;
369                 }
370                 break;
371               }
372             }
373             break;
374
375           case SimpleEvent::MOUSE_CLICK_RELEASE:
376             switch(se.button) {
377             case 1:
378               if(press_shift) manipulator.do_action(Manipulator::ACTION_RELEASE);
379               else            grabbed_polygon = 0;
380               break;
381             default:
382               break;
383             }
384
385           case SimpleEvent::MOUSE_MOTION:
386             {
387               if(press_shift) {
388                 manipulator.force_move(se.x, se.y);
389               } else if(grabbed_polygon) {
390                 scalar_t xf, yf, force_x, force_y, f, fmax = 100;
391                 xf = grabbed_polygon->absolute_x(relative_grab_x, relative_grab_y);
392                 yf = grabbed_polygon->absolute_y(relative_grab_x, relative_grab_y);
393                 force_x = se.x - xf;
394                 force_y = se.y - yf;
395                 f = sqrt(sq(force_x) + sq(force_y));
396                 if(f > fmax) { force_x = (force_x * fmax)/f; force_y = (force_y * fmax)/f; }
397                 grabbed_polygon->apply_force(0.1, xf, yf, force_x, force_y);
398               }
399               break;
400             }
401             break;
402
403           case SimpleEvent::KEY_PRESS:
404             {
405               if(strcmp(se.key, "q") == 0) {
406                 quit = true;
407               }
408
409               else if(strcmp(se.key, "s") == 0) {
410
411                 {
412                   Plotter plotter(int(universe.width()), int(universe.height()), 4);
413                   plotter.save_as_ppm(&universe, "/tmp/plotter.ppm", 16);
414                 }
415
416 #ifdef CAIRO_SUPPORT
417                 {
418                   FILE *file = fopen("/tmp/screenshot.png", "w");
419                   generate_png(&universe, 0.25, file);
420                   cout << "Universe image saved in /tmp/screenshot.png" << endl;
421                   fclose(file);
422                 }
423 #endif
424
425               }
426
427               else if(strcmp(se.key, "Shift_L") == 0 || strcmp(se.key, "Shift_R") == 0) {
428                 press_shift = true;
429               }
430
431               else if(strcmp(se.key, "Up") == 0) {
432                 manipulator.do_action(Manipulator::ACTION_MOVE_UP);
433               }
434
435               else if(strcmp(se.key, "Right") == 0) {
436                 manipulator.do_action(Manipulator::ACTION_MOVE_RIGHT);
437               }
438
439               else if(strcmp(se.key, "Down") == 0) {
440                 manipulator.do_action(Manipulator::ACTION_MOVE_DOWN);
441               }
442
443               else if(strcmp(se.key, "Left") == 0) {
444                 manipulator.do_action(Manipulator::ACTION_MOVE_LEFT);
445               }
446
447               else if(strcmp(se.key, "g") == 0) {
448                 manipulator.do_action(Manipulator::ACTION_GRAB);
449               }
450
451               else if(strcmp(se.key, "r") == 0) {
452                 manipulator.do_action(Manipulator::ACTION_RELEASE);
453               }
454
455               else if(strcmp(se.key, "space") == 0) {
456                 switch(action_mode) {
457                 case IDLE:
458                   action_mode = RANDOM;
459                   cout << "Switched to random mode" << endl;
460                   break;
461                 case RANDOM:
462                   action_mode = INTELLIGENT;
463                   cout << "Switched to intelligent mode" << endl;
464                   break;
465                 case INTELLIGENT:
466                   cout << "Switched to idle mode" << endl;
467                   action_mode = IDLE;
468                   break;
469                 }
470               }
471
472               else cout << "Undefined key " << se.key << endl;
473             }
474             break;
475
476           case SimpleEvent::KEY_RELEASE:
477             {
478               if(strcmp(se.key, "Shift_L") == 0 || strcmp(se.key, "Shift_R") == 0) press_shift = false;
479             }
480             break;
481
482           default:
483             break;
484
485           }
486         } while(se.type != SimpleEvent::NO_EVENT);
487       } else {
488         cerr << "Error on select: " << strerror(errno) << endl;
489         exit(1);
490       }
491     }
492   }
493
494   if(proportion_for_training > 0) {
495     cout << "Learning ... "; cout.flush();
496     intelligence.learn(proportion_for_training);
497     cout << "done." << endl;
498   }
499
500   if(intelligence_save_file[0]) {
501     cout << "Saving to " << intelligence_save_file << endl; cout.flush();
502     ofstream os(intelligence_save_file);
503     if(os.fail()) {
504       cerr << "error writing " << intelligence_save_file << "." << endl;
505       exit(1);
506     }
507     cout << "done." << endl;
508     intelligence.save(os);
509   }
510
511   delete window_main;
512
513 }