f682a169076aca29f3f7b4b695bb55fc179b305d
[pytorch.git] / speed.py
1 #!/usr/bin/env python
2
3 import time, torch
4
5 if torch.cuda.is_available():
6     device = torch.device('cuda')
7 else:
8     device = torch.device('cpu')
9
10 nb_runs = 10000
11 d1, d2, d3 = 50000, 256, 512
12
13 a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device)
14
15 start_time = time.perf_counter()
16 for k in range(nb_runs):
17     c = a @ b
18 duration = time.perf_counter() - start_time
19
20 nb_flop = float(nb_runs * d1 * d2 * d3)
21 speed = nb_flop / duration
22
23 for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
24     if speed < 1e3: break
25     speed /= 1e3
26
27 print(f'{speed:.02f} {u}flops on {device}')
28