Measures for FP16 and FP32
authorFrancois Fleuret <francois@fleuret.org>
Thu, 3 Sep 2020 04:45:19 +0000 (06:45 +0200)
committerFrancois Fleuret <francois@fleuret.org>
Thu, 3 Sep 2020 04:45:19 +0000 (06:45 +0200)
speed.py

index e4add26..9e845db 100755 (executable)
--- a/speed.py
+++ b/speed.py
@@ -12,21 +12,22 @@ else:
 nb_runs = 10000
 d1, d2, d3 = 2048, 2048, 2048
 
-a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device)
+for t in [ torch.float32, torch.float16 ]:
+    a = torch.rand(d1, d2, device = device, dtype = t)
+    b = torch.rand(d2, d3, device = device, dtype = t)
 
-sync()
-start_time = time.perf_counter()
-for k in range(nb_runs):
-    c = a @ b
-sync()
-duration = time.perf_counter() - start_time
+    sync()
+    start_time = time.perf_counter()
+    for k in range(nb_runs):
+        c = a @ b
+    sync()
+    duration = time.perf_counter() - start_time
 
-nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
-speed = nb_flop / duration
+    nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
+    speed = nb_flop / duration
 
-for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
-    if speed < 1e3: break
-    speed /= 1e3
-
-print(f'{speed:.02f} {u}flops on {device}')
+    for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
+        if speed < 1e3: break
+        speed /= 1e3
 
+    print(f'{speed:.02f} {u}flops with {t} on {device}')