X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=expr.py;fp=expr.py;h=8a899459e4ff4bcc0f731575445fd9218f07b965;hb=b5fd9b344c8c782460941c604b6e637d7549fe7d;hp=0000000000000000000000000000000000000000;hpb=0b147af672d69d5fca328bc937467993c22fb20d;p=picoclvr.git diff --git a/expr.py b/expr.py new file mode 100755 index 0000000..8a89945 --- /dev/null +++ b/expr.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +import math + +import torch, torchvision + +from torch import nn +from torch.nn import functional as F + +def random_var(nb_variables=None, variables=None): + if variables is None: + return chr(ord('A') + torch.randint(nb_variables, (1,)).item()) + else: + l = list(variables) + return l[torch.randint(len(l), (1,)).item()] + +def random_expr(variables, budget): + if budget <= 5: + op=torch.randint(2, (1,)).item() + if op == 0 and len(variables) > 0: + return random_var(variables=variables) + else: + return str(torch.randint(10, (1,)).item()) + else: + op=torch.randint(4, (1,)).item() + if op == 0: + e=random_expr(variables,budget-2) + if ("+" in e or "-" in e or "*" in e) and (e[0]!="(" or e[-1]!=")"): + return "("+e+")" + else: + return e + else: + b = 2 + torch.randint(budget-5, (1,)).item() + e1=random_expr(variables,b) + e2=random_expr(variables,budget-b-1) + if op == 1: + return e1+"+"+e2 + elif op == 2: + return e1+"+"+e2 + elif op == 3: + return e1+"*"+e2 + +def generate_program(nb_variables, length): + s = "" + variables = set() + while len(s) < length: + v = random_var(nb_variables=nb_variables) + s += v+"="+random_expr(variables,budget = min(20,length-3-len(s)))+";" + variables.add(v) + return s, variables + +def generate_sequences(nb, nb_variables = 5, length=20): + sequences=[] + for n in range(nb): + result = None + while result==None or max(result.values())>100: + p,v=generate_program(nb_variables, length) + v=", ".join([ "\""+v+"\": "+v for v in v ]) + ldict={} + exec(p+"result={"+v+"}",globals(),ldict) + result=ldict["result"] + + k=list(result.keys()) + k.sort() + sequences.append(p+" "+";".join([v+":"+str(result[v]) for v in k])) + + return sequences + +if __name__ == "__main__": + import time + start_time = time.perf_counter() + sequences=generate_sequences(1000) + end_time = time.perf_counter() + for s in sequences[:10]: + print(s) + print(f"{len(sequences) / (end_time - start_time):.02f} samples per second") +