Update.
[pytorch.git] / speed.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import time, torch
9
10 if torch.cuda.is_available():
11     device = torch.device("cuda")
12     sync = torch.cuda.synchronize
13 else:
14     device = torch.device("cpu")
15     sync = lambda: None
16
17 max_duration = 30
18 d1, d2, d3 = 2048, 2048, 2048
19
20 for t in [torch.float32, torch.float16]:
21     try:
22         a = torch.rand(d1, d2, device=device, dtype=t)
23         b = torch.rand(d2, d3, device=device, dtype=t)
24         nb_runs = 0
25
26         sync()
27         start_time = time.perf_counter()
28         while time.perf_counter() - start_time < max_duration:
29             c = a @ b
30             nb_runs += 1
31         sync()
32         duration = time.perf_counter() - start_time
33
34         nb_flop = float(nb_runs * d1 * d2 * d3 * 2)  # 1 multiply-and-add is 2 ops
35         speed = nb_flop / duration
36
37         for u in ["", "K", "M", "G", "T", "P"]:
38             if speed < 1e3:
39                 break
40             speed /= 1e3
41
42         print(f"{speed:.02f} {u}flops with {t} on {device}")
43
44     except:
45         print(f"{t} is not available on {device}")