Cleaning up a bit.
authorFrancois Fleuret <francois@fleuret.org>
Fri, 13 Mar 2020 07:53:37 +0000 (08:53 +0100)
committerFrancois Fleuret <francois@fleuret.org>
Fri, 13 Mar 2020 07:53:37 +0000 (08:53 +0100)
covid19.py

index ef9e393..78c81ed 100755 (executable)
@@ -11,33 +11,38 @@ import matplotlib.pyplot as plt
 import matplotlib.dates as mdates
 import urllib.request
 
-url = 'https://github.com/CSSEGISandData/COVID-19/raw/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv'
+######################################################################
 
-file = url[url.rfind('/')+1:]
+def gentle_download(url, delay = 86400):
+    filename = url[url.rfind('/') + 1:]
+    if not os.path.isfile(filename) or os.path.getmtime(filename) < time.time() - delay:
+        print(f'Retrieving {url}')
+        urllib.request.urlretrieve(url, filename)
+    return filename
 
 ######################################################################
 
-if not os.path.isfile(file) or os.path.getmtime(file) < time.time() - 86400:
-    print('Retrieving file')
-    urllib.request.urlretrieve(url, file)
+nbcases_filename = gentle_download(
+    'https://github.com/CSSEGISandData/COVID-19/raw/master/csse_covid_19_data/csse_covid_19_time_series/time_series_19-covid-Confirmed.csv'
+)
 
 ######################################################################
 
-with open(file, newline='') as csvfile:
+with open(nbcases_filename, newline='') as csvfile:
     reader = csv.reader(csvfile, delimiter=',')
     times = []
     nb_cases = {}
     time_col = 5
     for row_nb, row in enumerate(reader):
         for col_nb, field in enumerate(row):
-            if row_nb >= 1 and col_nb == 1:
-                country = field
-                if not country in nb_cases:
-                    nb_cases[country] = numpy.zeros(len(times))
             if row_nb == 0 and col_nb >= time_col:
                 times.append(time.mktime(time.strptime(field, '%m/%d/%y')))
             if row_nb >= 1:
-                if col_nb >= time_col:
+                if col_nb == 1:
+                    country = field
+                    if not country in nb_cases:
+                        nb_cases[country] = numpy.zeros(len(times))
+                elif col_nb >= time_col:
                     nb_cases[country][col_nb - time_col] += int(field)
 
 countries = list(nb_cases.keys())
@@ -70,12 +75,13 @@ for key, color, label in [
         ('Italy', 'purple', 'Italy'),
         ('China', 'orange', 'China')
 ]:
-    ax.plot(dates, nb_cases[key], color = color, label = label, linewidth=2)
+    ax.plot(dates, nb_cases[key],
+            color = color, label = label, linewidth = 2)
 
 ax.legend(frameon = False)
 
 plt.show()
 
-fig.savefig('covid19.png')
+fig.savefig('covid19_nb_cases.png')
 
 ######################################################################