파이썬으로 코로나19 확진자수 예측하기 (Fitting data with NumPy, SciPy, Matplotlib)

NumPy, SciPy는 수학·과학 계산에 유용한 파이썬 라이브러리이고,

Matplotlib은 그래프 플롯을 위한 라이브러리입니다.


파이썬과 NumPy, SciPy, Matplotlib 라이브러리를 이용해서 데이터 피팅을 수행하고, 그래프로 나타내는 방법을 소개합니다.

1월 16일부터 3월 1일까지 15일 동안의 날짜별 코로나19 확진자수 데이터를 사용했고,

피팅과 예측 모델로써, 간단한 \(y = a\cdot exp(b\cdot x)\) 함수를 사용했습니다.


예제

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import numpy as np


def func(x, a, b):
    return a * np.exp(b * x)


day = np.arange(1, 16)
cases = np.array([30, 30, 31, 51, 104, 204, 433, 602, 833, 977, 1261, 1766, 2337, 3150, 3736])
plt.plot(day, cases, 'co', label='Total Number of Cases (Korea)')

popt, pcov = curve_fit(func, day, cases, p0=(1, 0.1))
plt.plot(day, func(day, *popt), 'r-', alpha=0.5, label='fit: a = %5.3f, b = %5.3f' % tuple(popt))

plt.xlabel('Days Since January 16')
plt.ylabel('Total Number of Cases (Korea)')
plt.legend()
plt.show()


설명

import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import numpy as np

필요한 라이브러리와 모듈을 불러옵니다.

scipy.optimize는 자주 사용되는 최적화 알고리즘을 제공하는 패키지입니다.

curve_fit최소제곱법을 이용해서 주어진 함수를 데이터에 피팅하는 기능을 갖는 함수입니다.



day = np.arange(1, 16)
cases = np.array([30, 30, 31, 51, 104, 204, 433, 602, 833, 977, 1261, 1766, 2337, 3150, 3736])
plt.plot(day, cases, 'co', label='Total Number of Cases (Korea)')

day는 1월 16일로부터 경과일을 나타내는 NumPy 어레이이고,

cases는 날짜별 확진자수를 나타내는 NumPy 어레이입니다.

matplotlib.pyplot의 plot() 함수를 이용해서 day와 cases를 x, y 값으로 하는 그래프를 나타냅니다.



popt, pcov = curve_fit(func, day, cases, p0=(1, 0.1))
plt.plot(day, func(day, *popt), 'r-', alpha=0.5, label='fit: a = %5.3f, b = %5.3f' % tuple(popt))

popt와 pcov는 curve_fit() 함수가 반환하는 값입니다.

popt는 최소제곱법을 이용해서 함수와 데이터를 피팅하는 최적화된 파라미터 값이고, pcov는 covariance의 행렬입니다.

(자세한 내용은 scipy.optimize.curve_fit 페이지를 참고하세요.)

func(day, *popt)는 최적화된 파라미터를 사용한 피팅 함수의 값입니다.

최적 파라미터의 값 popt를 그래프에 표시하도록 했습니다.



plt.xlabel('Days Since January 16')
plt.ylabel('Total Number of Cases (Korea)')
plt.legend()
plt.show()

plt.xlabe()과 plt.ylabel()을 사용해서 x, y 축의 레이블을 나타내고,

plt.legend()를 이용해서 범례를 나타냅니다.

결과는 아래와 같습니다.


결과

_images/fitting_data_01.png