e03b3b74febe038e3e7382fb155909e197c71cc1
[pytorch.git] / speed.py
1 #!/usr/bin/env python
2
3 # Any copyright is dedicated to the Public Domain.
4 # https://creativecommons.org/publicdomain/zero/1.0/
5
6 # Written by Francois Fleuret <francois@fleuret.org>
7
8 import time, torch
9
10 if torch.cuda.is_available():
11     device = torch.device('cuda')
12     sync = lambda: torch.cuda.synchronize()
13 else:
14     device = torch.device('cpu')
15     sync = lambda: None
16
17 nb_runs = 10000
18 d1, d2, d3 = 2048, 2048, 2048
19
20 for t in [ torch.float32, torch.float16 ]:
21     a = torch.rand(d1, d2, device = device, dtype = t)
22     b = torch.rand(d2, d3, device = device, dtype = t)
23
24     sync()
25     start_time = time.perf_counter()
26     for k in range(nb_runs):
27         c = a @ b
28     sync()
29     duration = time.perf_counter() - start_time
30
31     nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
32     speed = nb_flop / duration
33
34     for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
35         if speed < 1e3: break
36         speed /= 1e3
37
38     print(f'{speed:.02f} {u}flops with {t} on {device}')