# coding: utf-8
# pylint: disable = C0103
"""Plotting library."""
from __future__ import absolute_import
import warnings
from copy import deepcopy
from io import BytesIO
import numpy as np
from .basic import Booster
from .compat import (MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED, LGBMDeprecationWarning,
range_, zip_, string_type)
from .sklearn import LGBMModel
def _check_not_tuple_of_2_elements(obj, obj_name='obj'):
"""Check object is not tuple or does not have 2 elements."""
if not isinstance(obj, tuple) or len(obj) != 2:
raise TypeError('%s must be a tuple of 2 elements.' % obj_name)
def _float2str(value, precision=None):
return ("{0:.{1}f}".format(value, precision)
if precision is not None and not isinstance(value, string_type)
else str(value))
def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='Feature importance', ylabel='Features',
                    importance_type='split', max_num_features=None,
                    ignore_zero=True, figsize=None, grid=True,
                    precision=None, **kwargs):
    """Plot model's feature importances.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance which feature importance should be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    height : float, optional (default=0.2)
        Bar height, passed to ``ax.barh()``.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, 1.1 * max importance)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(-1, number of plotted features)`` is used.
    title : string or None, optional (default="Feature importance")
        Axes title.
        If None, title is disabled.
    xlabel : string or None, optional (default="Feature importance")
        X-axis title label.
        If None, title is disabled.
    ylabel : string or None, optional (default="Features")
        Y-axis title label.
        If None, title is disabled.
    importance_type : string, optional (default="split")
        How the importance is calculated.
        If "split", result contains numbers of times the feature is used in a model.
        If "gain", result contains total gains of splits which use the feature.
    max_num_features : int or None, optional (default=None)
        Max number of top features displayed on plot.
        If None or <1, all features will be displayed.
    ignore_zero : bool, optional (default=True)
        Whether to ignore features with zero importance.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size. Only used when ``ax`` is None.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    precision : int or None, optional (default=None)
        Used to restrict the display of floating point values to a certain precision.
        Only applied to the bar labels when ``importance_type`` is "gain".
    **kwargs
        Other parameters passed to ``ax.barh()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with model's feature importances.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither Booster nor LGBMModel, or if ``figsize``,
        ``xlim`` or ``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If the booster reports no importances, or if no feature is left to
        plot after filtering (e.g. all importances are zero and
        ``ignore_zero`` is True).
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.pyplot as plt
    else:
        raise ImportError('You must install matplotlib to plot importance.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    # ``len`` is used on purpose: truth-testing the importance container
    # directly may be ambiguous for array-like objects.
    if not len(importance):
        raise ValueError("Booster's feature_importance is empty.")

    # Ascending sort puts the most important feature at the top of the
    # horizontal bar chart and lets the slice below keep the top-N entries.
    tuples = sorted(zip_(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if max_num_features is not None and max_num_features > 0:
        tuples = tuples[-max_num_features:]

    # Without this guard the unpacking below fails with an opaque
    # "not enough values to unpack" error when nothing survives filtering.
    if not tuples:
        raise ValueError('No features left to plot. All importances may be zero; '
                         'pass ignore_zero=False to plot them anyway.')
    labels, values = zip_(*tuples)

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Annotate every bar with its value, slightly to the right of the bar end.
    for x, y in zip_(values, ylocs):
        ax.text(x + 1, y,
                _float2str(x, precision) if importance_type == 'gain' else x,
                va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
def plot_metric(booster, metric=None, dataset_names=None,
                ax=None, xlim=None, ylim=None,
                title='Metric during training',
                xlabel='Iterations', ylabel='auto',
                figsize=None, grid=True):
    """Plot one metric during training.

    Parameters
    ----------
    booster : dict or LGBMModel
        Dictionary returned from ``lightgbm.train()`` or LGBMModel instance.
    metric : string or None, optional (default=None)
        The metric name to plot.
        Only one metric supported because different metrics have various scales.
        If None, one metric is picked from the first dataset's results
        (a warning is issued when several are available).
    dataset_names : list of strings or None, optional (default=None)
        List of the dataset names which are used to calculate metric to plot.
        If None, all datasets are used.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
    title : string or None, optional (default="Metric during training")
        Axes title.
        If None, title is disabled.
    xlabel : string or None, optional (default="Iterations")
        X-axis title label.
        If None, title is disabled.
    ylabel : string or None, optional (default="auto")
        Y-axis title label.
        If 'auto', metric name is used.
        If None, title is disabled.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with metric's history over the training.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError('You must install matplotlib to plot metric.')
    import matplotlib.pyplot as plt

    # Work on a private copy: ``popitem`` below must not alter the caller's data.
    if isinstance(booster, LGBMModel):
        history = deepcopy(booster.evals_result_)
    elif isinstance(booster, dict):
        history = deepcopy(booster)
    else:
        raise TypeError('booster must be dict or LGBMModel.')

    if len(history) == 0:
        raise ValueError('eval results cannot be empty.')

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if dataset_names is None:
        names_iter = iter(history.keys())
    elif isinstance(dataset_names, (list, tuple, set)) and dataset_names:
        names_iter = iter(dataset_names)
    else:
        raise ValueError('dataset_names should be iterable and cannot be empty')

    # The first dataset decides which metric is drawn and how many
    # iterations the x-axis spans.
    first_name = next(names_iter)
    first_metrics = history[first_name]
    metric_count = len(first_metrics)
    if metric is None:
        if metric_count > 1:
            msg = """more than one metric available, picking one to plot."""
            warnings.warn(msg, stacklevel=2)
        metric, first_curve = first_metrics.popitem()
    elif metric in first_metrics:
        first_curve = first_metrics[metric]
    else:
        raise KeyError('No given metric in eval results.')

    num_iteration = len(first_curve)
    highest, lowest = max(first_curve), min(first_curve)
    x_ = range_(num_iteration)
    ax.plot(x_, first_curve, label=first_name)

    # Remaining datasets share the x-axis and widen the tracked value range.
    for other_name in names_iter:
        curve = history[other_name][metric]
        highest = max(max(curve), highest)
        lowest = min(min(curve), lowest)
        ax.plot(x_, curve, label=other_name)

    ax.legend(loc='best')

    if xlim is None:
        xlim = (0, num_iteration)
    else:
        _check_not_tuple_of_2_elements(xlim, 'xlim')
    ax.set_xlim(xlim)

    if ylim is None:
        # Pad the observed range by 20% on both sides.
        margin = (highest - lowest) * 0.2
        ylim = (lowest - margin, highest + margin)
    else:
        _check_not_tuple_of_2_elements(ylim, 'ylim')
    ax.set_ylim(ylim)

    if ylabel == 'auto':
        ylabel = metric

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax
def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
    """Convert specified tree to graphviz instance.

    See:
      - https://graphviz.readthedocs.io/en/stable/api.html#digraph

    Parameters
    ----------
    tree_info : dict
        Tree description; its 'tree_structure' entry is the root node.
    show_info : list of strings
        Extra fields to show in node labels
        ('split_gain', 'internal_value', 'internal_count', 'leaf_count').
    feature_names : list of strings or None
        If not None, split features are labelled by name instead of by index.
    precision : int or None, optional (default=None)
        Precision used when rendering floating point values.
    **kwargs
        Other parameters passed to ``Digraph`` constructor.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph holding one node per split/leaf and one edge per branch.
    """
    if not GRAPHVIZ_INSTALLED:
        raise ImportError('You must install graphviz to plot tree.')
    from graphviz import Digraph

    graph = Digraph(**kwargs)

    def split_label(node):
        """Build the label text of an internal (split) node."""
        if feature_names is not None:
            lines = ['split_feature_name: {0}'.format(feature_names[node['split_feature']])]
        else:
            lines = ['split_feature_index: {0}'.format(node['split_feature'])]
        lines.append('threshold: {0}'.format(_float2str(node['threshold'], precision)))
        for info in show_info:
            if info in {'split_gain', 'internal_value'}:
                lines.append('{0}: {1}'.format(info, _float2str(node[info], precision)))
            elif info == 'internal_count':
                lines.append('{0}: {1}'.format(info, node[info]))
        # A raw backslash-n is the line separator inside graphviz labels.
        return r'\n'.join(lines)

    def leaf_label(node):
        """Build the label text of a leaf node."""
        lines = ['leaf_index: {0}'.format(node['leaf_index']),
                 'leaf_value: {0}'.format(_float2str(node['leaf_value'], precision))]
        if 'leaf_count' in show_info:
            lines.append('leaf_count: {0}'.format(node['leaf_count']))
        return r'\n'.join(lines)

    def add(root, parent=None, decision=None):
        """Recursively add node or edge."""
        if 'split_index' in root:  # non-leaf
            name = 'split{0}'.format(root['split_index'])
            graph.node(name, label=split_label(root))
            kind = root['decision_type']
            if kind == '<=':
                l_dec, r_dec = '<=', '>'
            elif kind == '==':
                l_dec, r_dec = 'is', "isn't"
            else:
                raise ValueError('Invalid decision type in tree model.')
            add(root['left_child'], name, l_dec)
            add(root['right_child'], name, r_dec)
        else:  # leaf
            name = 'leaf{0}'.format(root['leaf_index'])
            graph.node(name, label=leaf_label(root))
        # The incoming edge is emitted after the whole subtree.
        if parent is not None:
            graph.edge(parent, name, decision)

    add(tree_info['tree_structure'])
    return graph
def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
                        old_name=None, old_comment=None, old_filename=None, old_directory=None,
                        old_format=None, old_engine=None, old_encoding=None, old_graph_attr=None,
                        old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False, **kwargs):
    """Create a digraph representation of specified tree.

    Note
    ----
    For more information please visit
    https://graphviz.readthedocs.io/en/stable/api.html#digraph.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be converted.
    tree_index : int, optional (default=0)
        The index of a target tree to convert.
    show_info : list of strings or None, optional (default=None)
        What information should be shown in nodes.
        Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
    precision : int or None, optional (default=None)
        Used to restrict the display of floating point values to a certain precision.
    old_name, old_comment, old_filename, old_directory, old_format, old_engine, \
old_encoding, old_graph_attr, old_node_attr, old_edge_attr, old_body, old_strict
        Deprecated. Each one is forwarded to ``Digraph`` under the name without
        the ``old_`` prefix unless the same key is already present in ``**kwargs``.
        Please use ``**kwargs`` instead.
    **kwargs
        Other parameters passed to ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of specified tree.

    Raises
    ------
    TypeError
        If ``booster`` is neither Booster nor LGBMModel.
    IndexError
        If ``tree_index`` is out of range.
    """
    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    # Explicit (new name, value) pairs instead of ``locals()`` lookups, which
    # silently return None when the looked-up name is misspelled.
    deprecated_params = (
        ('name', old_name),
        ('comment', old_comment),
        ('filename', old_filename),
        ('directory', old_directory),
        ('format', old_format),
        ('engine', old_engine),
        ('encoding', old_encoding),
        ('graph_attr', old_graph_attr),
        ('node_attr', old_node_attr),
        ('edge_attr', old_edge_attr),
        ('body', old_body),
    )
    for new_name, value in deprecated_params:
        if value is not None:
            warnings.warn('old_{0} parameter is deprecated and will be removed in 2.4 version.\n'
                          'Please use **kwargs to pass {0} parameter.'.format(new_name),
                          LGBMDeprecationWarning)
            # An explicitly passed keyword always wins over the deprecated alias.
            if new_name not in kwargs:
                kwargs[new_name] = value

    # The parameter is named ``old_strict``; looking up 'strict' among the
    # locals never matched, so the deprecated flag used to be ignored.
    if old_strict:
        warnings.warn('old_strict parameter is deprecated and will be removed in 2.4 version.\n'
                      'Please use **kwargs to pass strict parameter.',
                      LGBMDeprecationWarning)
        if 'strict' not in kwargs:
            kwargs['strict'] = True

    model = booster.dump_model()
    tree_infos = model['tree_info']
    if 'feature_names' in model:
        feature_names = model['feature_names']
    else:
        feature_names = None

    if tree_index < len(tree_infos):
        tree_info = tree_infos[tree_index]
    else:
        raise IndexError('tree_index is out of range.')

    if show_info is None:
        show_info = []

    graph = _to_graphviz(tree_info, show_info, feature_names, precision, **kwargs)
    return graph
def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              old_graph_attr=None, old_node_attr=None, old_edge_attr=None,
              show_info=None, precision=None, **kwargs):
    """Plot specified tree.

    Note
    ----
    It is preferable to use ``create_tree_digraph()`` because of its lossless quality
    and returned objects can be also rendered and displayed directly inside a Jupyter notebook.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    tree_index : int, optional (default=0)
        The index of a target tree to plot.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size.
    old_graph_attr, old_node_attr, old_edge_attr
        Deprecated. Forwarded as ``graph_attr``, ``node_attr`` and ``edge_attr``
        unless the same key is already present in ``**kwargs``.
    show_info : list of strings or None, optional (default=None)
        What information should be shown in nodes.
        Possible values of list items: 'split_gain', 'internal_value', 'internal_count', 'leaf_count'.
    precision : int or None, optional (default=None)
        Used to restrict the display of floating point values to a certain precision.
    **kwargs
        Other parameters passed to ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph for the full list of supported parameters.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with single tree.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError('You must install matplotlib to plot tree.')
    import matplotlib.pyplot as plt
    import matplotlib.image as image

    legacy_attrs = (
        ('graph_attr', old_graph_attr),
        ('node_attr', old_node_attr),
        ('edge_attr', old_edge_attr),
    )
    for key, legacy_value in legacy_attrs:
        if legacy_value is None:
            continue
        warnings.warn('{0} parameter is deprecated and will be removed in 2.4 version.\n'
                      'Please use **kwargs to pass {1} parameter.'.format('old_' + key, key),
                      LGBMDeprecationWarning)
        # Keep an explicitly supplied keyword; only fill in a missing one.
        kwargs.setdefault(key, legacy_value)

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    graph = create_tree_digraph(booster=booster, tree_index=tree_index,
                                show_info=show_info, precision=precision, **kwargs)

    # Render to PNG in memory and hand the bytes to matplotlib as a file-like
    # object positioned at its start.
    png_stream = BytesIO(graph.pipe(format='png'))
    img = image.imread(png_stream)

    ax.imshow(img)
    ax.axis('off')
    return ax