Update.
authorFrancois Fleuret <francois@fleuret.org>
Thu, 3 Sep 2020 06:18:03 +0000 (08:18 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 3 Sep 2020 06:18:03 +0000 (08:18 +0200)
speed.py

index e03b3b7..e5b0e3a 100755 (executable)
--- a/speed.py
+++ b/speed.py
@@ -9,30 +9,37 @@ import time, torch
 
 if torch.cuda.is_available():
     device = torch.device('cuda')
 
 if torch.cuda.is_available():
     device = torch.device('cuda')
-    sync = lambda: torch.cuda.synchronize()
+    sync = torch.cuda.synchronize
 else:
     device = torch.device('cpu')
     sync = lambda: None
 
 else:
     device = torch.device('cpu')
     sync = lambda: None
 
-nb_runs = 10000
+max_duration = 30
 d1, d2, d3 = 2048, 2048, 2048
 
 for t in [ torch.float32, torch.float16 ]:
 d1, d2, d3 = 2048, 2048, 2048
 
 for t in [ torch.float32, torch.float16 ]:
-    a = torch.rand(d1, d2, device = device, dtype = t)
-    b = torch.rand(d2, d3, device = device, dtype = t)
+    try:
+        a = torch.rand(d1, d2, device = device, dtype = t)
+        b = torch.rand(d2, d3, device = device, dtype = t)
+        nb_runs = 0
 
 
-    sync()
-    start_time = time.perf_counter()
-    for k in range(nb_runs):
-        c = a @ b
-    sync()
-    duration = time.perf_counter() - start_time
+        sync()
+        start_time = time.perf_counter()
+        while time.perf_counter() - start_time < max_duration:
+            c = a @ b
+            nb_runs += 1
+        sync()
+        duration = time.perf_counter() - start_time
 
 
-    nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
-    speed = nb_flop / duration
+        nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
+        speed = nb_flop / duration
 
 
-    for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
-        if speed < 1e3: break
-        speed /= 1e3
+        for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
+            if speed < 1e3: break
+            speed /= 1e3
 
 
-    print(f'{speed:.02f} {u}flops with {t} on {device}')
+        print(f'{speed:.02f} {u}flops with {t} on {device}')
+
+    except:
+
+        print(f'Cannot try with {t}')