Gaussian HMM of stock data
This script shows how to use a Gaussian HMM on stock price data from
Yahoo! Finance. For more information on how to visualize stock prices
with matplotlib, please refer to the date_demo1.py example that ships
with matplotlib.
from __future__ import print_function

import datetime
import sys

import numpy as np
from matplotlib import cm, pyplot as plt
from matplotlib.dates import YearLocator, MonthLocator
try:
    from matplotlib.finance import quotes_historical_yahoo_ochl
except ImportError:
    # For Matplotlib prior to 1.5.
    from matplotlib.finance import (
        quotes_historical_yahoo as quotes_historical_yahoo_ochl
    )
from hmmlearn.hmm import GaussianHMM
# Echo this module's docstring (``__doc__``) before doing any work.
print(__doc__)
Get quotes from Yahoo! finance
# Download the historical INTC quotes for 1995-01-01 .. 2012-01-06.
quotes = quotes_historical_yahoo_ochl(
    "INTC",
    datetime.date(1995, 1, 1),
    datetime.date(2012, 1, 6),
)

# Pull the columns we need out of each quote record by position:
# field 0 is used as the date, field 2 as the close value and field 5 as
# the volume. Dates are stored as integers.
#
# ``np.diff`` yields one element fewer than its input
# (``len(diff) == len(close_v) - 1`` before trimming), so the first sample
# of every other series is dropped to keep all arrays aligned.
dates = np.array([quote[0] for quote in quotes], dtype=int)[1:]
volume = np.array([quote[5] for quote in quotes])[1:]
close_v = np.array([quote[2] for quote in quotes])
diff = np.diff(close_v)
close_v = close_v[1:]

# Training observations: one row per day, columns = (close diff, volume).
X = np.column_stack((diff, volume))
Run Gaussian HMM
# Announce the fit before starting it. The message deliberately has no
# trailing newline (so "done" lands on the same line), which means a
# line-buffered stdout would hold it back until "done" is printed; flush
# explicitly so the progress hint is visible while the model is fitting.
print("fitting to HMM and decoding ...", end="")
sys.stdout.flush()

# Make a 4-component Gaussian HMM with diagonal covariances and fit it to
# the observations X (n_iter=1000 is passed through to the estimator).
model = GaussianHMM(n_components=4, covariance_type="diag", n_iter=1000).fit(X)

# Predict the optimal sequence of internal hidden states for X.
hidden_states = model.predict(X)

print("done")
Print trained parameters and plot
# Report the fitted state-transition matrix.
print("Transition matrix")
print(model.transmat_)
print()

# Report, for every hidden state, its mean vector and the diagonal of its
# covariance entry (printed as the per-feature variances).
print("Means and vars of each hidden state")
state_params = zip(model.means_, model.covars_)
for i, (state_mean, state_covar) in enumerate(state_params):
    print("{0}th hidden state".format(i))
    print("mean = ", state_mean)
    print("var = ", np.diag(state_covar))
    print()
# One subplot per hidden state, all sharing both axes so the panels are
# directly comparable, and one rainbow colour per state.
fig, axs = plt.subplots(model.n_components, sharex=True, sharey=True)
colours = cm.rainbow(np.linspace(0, 1, model.n_components))

for state, ax in enumerate(axs):
    # Boolean mask selecting the samples decoded as this state.
    in_state = hidden_states == state
    ax.plot_date(dates[in_state], close_v[in_state], ".-", c=colours[state])
    ax.set_title("{0}th hidden state".format(state))

    # Major ticks on years, minor ticks on months.
    ax.xaxis.set_major_locator(YearLocator())
    ax.xaxis.set_minor_locator(MonthLocator())

    ax.grid(True)

plt.show()
Total running time of the script: (0 minutes 0.000 seconds)