前回はSARIMAモデルで乗客数を予測してみました。
今回は大本命のprophetを使ってモデリングをしようと思います。
アップル引っ越しのデータセットで既にprophetを使ったことがありますので、こちらの記事も参考になれば幸いです。
(オプション) Prophetのインストール
下記にprophetのインストール手順が記載されていますので詳細は割愛します。
ライブラリ間の依存関係もあるので、まずはPythonの仮想環境を作成することをおすすめしています。
簡単にまとめると下記のような形で仮想環境の作成とprophetのインストールができると思います。
# 仮想環境作成
python3 -m venv venv-prophet
# 仮想環境内に入る
source venv-prophet/bin/activate
# prophetとplotlyのインストール
python3 -m pip install prophet
python3 -m pip install plotly
seabornのflightデータの確認
flightsデータの中身をビジュアル化しています。確認したい方は下記から参照ください。
USの国際線乗客数をprophetで予想する
ARモデルやSARIMAモデルの時は非定常過程のデータを定常過程に変換していましたが、prophetの場合はそのままの状態で利用し全てアルゴリズムにお任せしてみます。
prophet is not really a timeseries model. its just estimating the signal as a function of time - so its good for trends and periodic behaviour.
参考: https://stats.stackexchange.com/questions/591425/time-series-data-transformation-for-prophet-model
# 描画設定
from IPython.display import HTML
import seaborn as sns
from matplotlib import ticker
import matplotlib.pyplot as plt
sns.set_style("whitegrid")
from matplotlib import rcParams
rcParams['font.family'] = 'Hiragino Sans' # Macの場合
#rcParams['font.family'] = 'Meiryo' # Windowsの場合
#rcParams['font.family'] = 'VL PGothic' # Linuxの場合
rcParams['xtick.labelsize'] = 12 # x軸のラベルのフォントサイズ
rcParams['ytick.labelsize'] = 12 # y軸のラベルのフォントサイズ
rcParams['axes.labelsize'] = 18 # ラベルのフォントとサイズ
rcParams['figure.figsize'] = 18,8 # 画像サイズの変更(inch)
データの読み込みとデータ加工
import pandas as pd
# データの読み込み (既に読み込み済みだが、読みやすさのためもう一度実行する)
flights = sns.load_dataset("flights")
# 日付変換用カラムの作成
flights["date"] = flights.year.astype('str') + "-" + flights.month.astype('str')
# 日付カラムの作成
flights["ds"] = pd.to_datetime(flights["date"], format="%Y-%b")
# 目的変数の型をfloat64に変換 (日付カラムはindexに設定しないこと)
flights = flights.set_index('ds',drop=False).astype({'passengers': 'float64'}).drop(["year","month","date"], axis=1)
# prophet用にpassengersカラムの名称をyに変更
flights = flights.rename(columns={'passengers': 'y'})
# 日付の間隔をasfreqメソッドで変更する
flights = flights.asfreq('MS')
flights
y ds ds 1949-01-01 112.0 1949-01-01 1949-02-01 118.0 1949-02-01 1949-03-01 132.0 1949-03-01 1949-04-01 129.0 1949-04-01 1949-05-01 121.0 1949-05-01 ... ... ... 1960-08-01 606.0 1960-08-01 1960-09-01 508.0 1960-09-01 1960-10-01 461.0 1960-10-01 1960-11-01 390.0 1960-11-01 1960-12-01 432.0 1960-12-01 144 rows × 2 columns
訓練データとテストデータに分割
# 訓練用とテスト用に分ける
df = flights.iloc[:-12].copy() # 1949-01 ~ 1959-12
df_test = flights.iloc[-12:].copy() # 1960-01 ~ 1960-12
prophetのモデル定義の作成
from prophet import Prophet
# Prophetのモデルを「乗法」で作成する
m = Prophet(seasonality_mode='multiplicative')
# 休日情報の追加はしないでやってみる
# m.add_country_holidays(country_name='US')
# 月次変動の追加
m.add_seasonality(name='monthly', period=30.5, fourier_order=5)
prophetで学習と予測
# 学習する
m.fit(flights)
20:53:30 - cmdstanpy - INFO - Chain [1] start processing 20:53:30 - cmdstanpy - INFO - Chain [1] done processing
# サンプルに倣ってfutureという変数にテストデータを格納
future = df_test
future.tail()
y ds ds 1960-08-01 606.0 1960-08-01 1960-09-01 508.0 1960-09-01 1960-10-01 461.0 1960-10-01 1960-11-01 390.0 1960-11-01 1960-12-01 432.0 1960-12-01
# 時系列予測の実施
forecast = m.predict(future)
forecast = forecast.set_index('ds',drop=False)
forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail()
ds yhat yhat_lower yhat_upper ds 1960-08-01 1960-08-01 606.280191 596.077565 616.221777 1960-09-01 1960-09-01 514.147986 502.799201 524.273711 1960-10-01 1960-10-01 456.119044 445.630420 465.832296 1960-11-01 1960-11-01 390.898670 379.908335 401.735758 1960-12-01 1960-12-01 437.168026 426.236624 447.385007
# 予測結果を描画
fig1 = m.plot(forecast)
# トレンドや各種変動の描画
fig2 = m.plot_components(forecast)
トレンドや年次変動なども確認できるので便利ですね。
やっぱりUSの夏休み(サマーブレーク)である5月末あたりから8月までのの期間は利用が伸びていそうです。
予測も当たっていそうです。
plt.figure(figsize=(20, 3))
ytrue = flights.y.iloc[:-12] # 訓練データの実数値
ypred = forecast.yhat # 1960-01 ~ 1960-12の予測値
ytrue2 = flights.y.iloc[-12:] # 1960-01 ~ 1960-12の実測値
plt.plot(ytrue, label="Training Data")
plt.plot(ypred, label="Forecasts")
plt.plot(ytrue2, label="Test Data")
plt.title("airplane passengers prediction")
_ = plt.legend()
やっぱりprophetが一番! 笑
まとめ
prophetはモデリングの時間も短く精度も良いのでコスパが一番いいのではないかと思います。
時系列予測の案件を会社でやることになった方はとりあえずprophetを試してみてもよいのではないかと思います。
本記事で利用したライブラリのバージョン
import pandas as pd
import numpy as np
import plotly
import matplotlib as mpl
import scipy as scp
import prophet
print('pandas',pd.__version__)
print('numpy',np.__version__)
print('plotly',plotly.__version__)
print('matplotlib',mpl.__version__)
print('scipy',scp.__version__)
print('prophet',prophet.__version__)
pandas 1.4.3 numpy 1.23.2 plotly 5.10.0 matplotlib 3.5.3 scipy 1.9.0 prophet 1.0