# Excerpt of a script that times work on two 2048x2048 random matrices and
# prints a throughput figure in (prefixed) flops.
# NOTE(review): every line carries an embedded listing number and the numbering
# has gaps (8, 10-12, 14, 16-17, 20-21, 23, 26, 28-30), so statements are
# elided from this view and indentation has been lost -- confirm against the
# complete file before editing the logic.

# Device selection: when CUDA is available, use the GPU and define `sync` as a
# callable that invokes torch.cuda.synchronize().
5 if torch.cuda.is_available():
6 device = torch.device('cuda')
7 sync = lambda: torch.cuda.synchronize()
# CPU device; presumably inside an `else:` branch that is not visible here
# (along with a CPU definition of `sync`) -- TODO confirm.
9 device = torch.device('cpu')
# Matrix dimensions: `a` is d1 x d2 and `b` is d2 x d3.
13 d1, d2, d3 = 2048, 2048, 2048
# Random operands created with torch.rand and then moved to the chosen device.
15 a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device)
# Start of the timed region.
# NOTE(review): no call to `sync` is visible around the timing in this
# excerpt -- verify it is invoked before starting and before stopping the
# clock, otherwise pending GPU work could skew the measurement.
18 start_time = time.perf_counter()
# Repeats the measured operation nb_runs times. Neither the loop body nor the
# definition of `nb_runs` is visible here; presumably the body multiplies `a`
# by `b` -- TODO confirm.
19 for k in range(nb_runs):
# Elapsed time since start_time, as reported by time.perf_counter().
22 duration = time.perf_counter() - start_time
# Total operation count: d1 * d2 * d3 multiply-adds per run, two operations
# each, over nb_runs runs.
24 nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
# Throughput: operations divided by elapsed time.
25 speed = nb_flop / duration
# Walks the unit prefixes from none up to peta. The loop body is not visible;
# presumably it rescales `speed` and stops at the first suitable prefix,
# leaving that prefix in `u` -- TODO confirm.
27 for u in [ '', 'K', 'M', 'G', 'T', 'P' ]:
# Reports the speed with two decimals, the last prefix bound to `u`, and the
# device used.
31 print(f'{speed:.02f} {u}flops on {device}')