Update.
[picoclvr.git] / expr.py
1 #!/usr/bin/env python
2
3 import math
4
5 import torch, torchvision
6
7 from torch import nn
8 from torch.nn import functional as F
9
def random_var(nb_variables=None, variables=None):
    """Draw one variable name uniformly at random.

    Args:
        nb_variables: used only when ``variables`` is None; the name is drawn
            from the first ``nb_variables`` upper-case letters starting at "A".
        variables: optional collection of candidate names to choose from.

    Returns:
        A single variable name (str).
    """
    if variables is None:
        return chr(ord("A") + torch.randint(nb_variables, (1,)).item())
    # Sort the candidates: iterating a set of str follows hash order, which
    # changes between interpreter runs (PYTHONHASHSEED), so list(variables)
    # would make the draw irreproducible even under torch.manual_seed.
    candidates = sorted(variables)
    return candidates[torch.randint(len(candidates), (1,)).item()]
16
def random_expr(variables, budget):
    """Build a random arithmetic expression as a Python-syntax string.

    The expression is made of single-digit integer literals, names taken from
    ``variables``, the binary operators ``+``, ``-`` and ``*``, and
    parentheses.

    Args:
        variables: collection of names that may appear in the expression
            (may be empty, in which case only digits are used).
        budget: size allowance; at 5 or below a single token is returned,
            otherwise it is split between the sub-expressions.

    Returns:
        The expression as a str.
    """
    if budget <= 5:
        # Leaf: a variable (if any are available) or a digit, 50/50.
        op = torch.randint(2, (1,)).item()
        if op == 0 and len(variables) > 0:
            return random_var(variables=variables)
        return str(torch.randint(10, (1,)).item())

    op = torch.randint(4, (1,)).item()
    if op == 0:
        # Parenthesised sub-expression; parentheses are skipped when the child
        # has no operator (single token) or already starts with "(" and ends
        # with ")".
        # NOTE(review): "(1)+(2)" also passes the "already wrapped" test and is
        # returned bare; harmless for evaluation since the string itself is
        # what gets executed.
        e = random_expr(variables, budget - 2)
        if ("+" in e or "-" in e or "*" in e) and (e[0] != "(" or e[-1] != ")"):
            return "(" + e + ")"
        return e

    # Binary node: split the remaining budget between the two operands.
    b = 2 + torch.randint(budget - 5, (1,)).item()
    e1 = random_expr(variables, b)
    e2 = random_expr(variables, budget - b - 1)
    if op == 1:
        return e1 + "+" + e2
    if op == 2:
        # Subtraction (previously a duplicate "+", which made "-" unreachable).
        return e1 + "-" + e2
    return e1 + "*" + e2
42
def generate_program(nb_variables, length):
    """Generate a random straight-line program of ``X=<expr>;`` assignments.

    Statements are appended until the program text reaches at least
    ``length`` characters. Each right-hand side may only reference variables
    assigned by earlier statements.

    Args:
        nb_variables: number of distinct variable names available
            (passed to ``random_var``).
        length: minimum length, in characters, of the generated program.

    Returns:
        A ``(program, variables)`` pair: the program text and the set of
        variable names it assigns.
    """
    statements = []
    assigned = set()
    written = 0
    while written < length:
        target = random_var(nb_variables=nb_variables)
        rhs = random_expr(assigned, budget=min(20, length - 3 - written))
        statement = f"{target}={rhs};"
        statements.append(statement)
        written += len(statement)
        # Only now is the target available to later right-hand sides.
        assigned.add(target)
    return "".join(statements), assigned
51
def generate_sequences(nb, nb_variables=5, length=20):
    """Generate ``nb`` random programs annotated with their final values.

    Each sequence has the form ``"<program> <X>:<value>;<Y>:<value>..."``,
    with variables listed in sorted order. A program whose largest final
    value exceeds 100 is discarded and a new one is drawn.

    Args:
        nb: number of sequences to generate.
        nb_variables: number of distinct variable names (see
            ``generate_program``).
        length: minimum program length in characters (see
            ``generate_program``).

    Returns:
        A list of ``nb`` strings.

    Raises:
        ValueError: from ``max()`` when ``nb > 0`` and ``length <= 0``,
            because the generated program is then empty.
    """
    sequences = []
    for _ in range(nb):
        result = None
        while result is None or max(result.values()) > 100:
            program, variables = generate_program(nb_variables, length)
            # ``program`` comes from generate_program, not from outside input,
            # so executing it is safe here.
            namespace = {}
            exec(program, globals(), namespace)
            result = {name: namespace[name] for name in variables}

        values = ";".join(f"{name}:{result[name]}" for name in sorted(result))
        sequences.append(program + " " + values)

    return sequences
68
if __name__ == "__main__":
    import time

    # Benchmark: generate 1000 samples, show the first ten, report throughput.
    t0 = time.perf_counter()
    sequences = generate_sequences(1000)
    elapsed = time.perf_counter() - t0
    print("\n".join(sequences[:10]))
    print(f"{len(sequences) / elapsed:.02f} samples per second")
77