# Source code for lightgbm.plotting

# coding: utf-8
# pylint: disable = C0103
"""Plotting library."""
from __future__ import absolute_import

import warnings
from copy import deepcopy
from io import BytesIO

import numpy as np

from .basic import Booster
from .compat import (MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED, LGBMDeprecationWarning,
                     range_, zip_, string_type)
from .sklearn import LGBMModel


def _check_not_tuple_of_2_elements(obj, obj_name='obj'):
    """Raise TypeError unless ``obj`` is a tuple holding exactly two elements.

    Parameters
    ----------
    obj : object
        Value to validate.
    obj_name : string, optional (default='obj')
        Name used for the value in the error message.
    """
    # ``len`` is only evaluated once ``obj`` is known to be a tuple.
    is_pair = isinstance(obj, tuple) and len(obj) == 2
    if is_pair:
        return
    raise TypeError('%s must be a tuple of 2 elements.' % obj_name)


def _float2str(value, precision=None):
    """Convert ``value`` to text, optionally with a fixed number of decimals.

    Parameters
    ----------
    value : object
        Value to render.
    precision : int or None, optional (default=None)
        Number of digits after the decimal point. If None, or if ``value``
        is already a string, ``str(value)`` is returned unchanged.

    Returns
    -------
    text : string
        The textual representation of ``value``.
    """
    if precision is None or isinstance(value, string_type):
        return str(value)
    return "{0:.{1}f}".format(value, precision)


def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='Feature importance', ylabel='Features',
                    importance_type='split', max_num_features=None,
                    ignore_zero=True, figsize=None, grid=True,
                    precision=None, **kwargs):
    """Plot model's feature importances.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance which feature importance should be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    height : float, optional (default=0.2)
        Bar height, passed to ``ax.barh()``.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
        If None, ``(0, max(values) * 1.1)`` is used.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
        If None, ``(-1, number of plotted features)`` is used.
    title : string or None, optional (default="Feature importance")
        Axes title.
        If None, title is disabled.
    xlabel : string or None, optional (default="Feature importance")
        X-axis title label.
        If None, title is disabled.
    ylabel : string or None, optional (default="Features")
        Y-axis title label.
        If None, title is disabled.
    importance_type : string, optional (default="split")
        How the importance is calculated.
        If "split", result contains numbers of times the feature is used in a model.
        If "gain", result contains total gains of splits which use the feature.
    max_num_features : int or None, optional (default=None)
        Max number of top features displayed on plot.
        If None or <1, all features will be displayed.
    ignore_zero : bool, optional (default=True)
        Whether to ignore features with zero importance.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size. Only used when ``ax`` is None.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.
    precision : int or None, optional (default=None)
        Used to restrict the display of floating point values to a certain precision.
        Only applied to the bar labels when ``importance_type`` is "gain".
    **kwargs
        Other parameters passed to ``ax.barh()``.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with model's feature importances.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If ``booster`` is neither Booster nor LGBMModel, or if ``figsize``,
        ``xlim`` or ``ylim`` is given but is not a tuple of 2 elements.
    ValueError
        If the booster reports no feature importances, or if no feature is
        left to plot after filtering.
    """
    if MATPLOTLIB_INSTALLED:
        import matplotlib.pyplot as plt
    else:
        raise ImportError('You must install matplotlib to plot importance.')

    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    importance = booster.feature_importance(importance_type=importance_type)
    feature_name = booster.feature_name()

    if not len(importance):
        raise ValueError("Booster's feature_importance is empty.")

    # Ascending order, so the most important feature ends up as the top bar.
    tuples = sorted(zip_(feature_name, importance), key=lambda x: x[1])
    if ignore_zero:
        tuples = [x for x in tuples if x[1] > 0]
    if not tuples:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError('No features to plot: all importances were filtered out. '
                         'Pass ignore_zero=False to plot zero-importance features.')
    if max_num_features is not None and max_num_features > 0:
        # The list is sorted ascending, so the tail holds the top features.
        tuples = tuples[-max_num_features:]
    labels, values = zip_(*tuples)

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ylocs = np.arange(len(values))
    ax.barh(ylocs, values, align='center', height=height, **kwargs)

    # Write each value just to the right of the end of its bar.
    for x, y in zip_(values, ylocs):
        ax.text(x + 1, y,
                _float2str(x, precision) if importance_type == 'gain' else x,
                va='center')

    ax.set_yticks(ylocs)
    ax.set_yticklabels(labels)

    if xlim is not None:
        _check_not_tuple_of_2_elements(xlim, 'xlim')
    else:
        xlim = (0, max(values) * 1.1)
    ax.set_xlim(xlim)

    if ylim is not None:
        _check_not_tuple_of_2_elements(ylim, 'ylim')
    else:
        ylim = (-1, len(values))
    ax.set_ylim(ylim)

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def plot_metric(booster, metric=None, dataset_names=None,
                ax=None, xlim=None, ylim=None,
                title='Metric during training',
                xlabel='Iterations', ylabel='auto',
                figsize=None, grid=True):
    """Plot the history of a single metric recorded during training.

    Parameters
    ----------
    booster : dict or LGBMModel
        Dictionary of evaluation results or LGBMModel instance
        (its ``evals_result_`` attribute is used).
    metric : string or None, optional (default=None)
        Name of the metric to plot.
        If None, one metric is taken from the first dataset's results
        (a warning is issued when more than one is available).
    dataset_names : list of strings or None, optional (default=None)
        Names of the datasets whose metric history is plotted.
        If None, all datasets are used.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    xlim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_xlim()``.
    ylim : tuple of 2 elements or None, optional (default=None)
        Tuple passed to ``ax.set_ylim()``.
    title : string or None, optional (default="Metric during training")
        Axes title.
        If None, title is disabled.
    xlabel : string or None, optional (default="Iterations")
        X-axis title label.
        If None, title is disabled.
    ylabel : string or None, optional (default="auto")
        Y-axis title label.
        If 'auto', metric name is used.
        If None, title is disabled.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size. Only used when ``ax`` is None.
    grid : bool, optional (default=True)
        Whether to add a grid for axes.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with metric's history over the training.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError('You must install matplotlib to plot metric.')
    import matplotlib.pyplot as plt

    # Work on a private copy: one entry may be popped from it below.
    if isinstance(booster, LGBMModel):
        history = deepcopy(booster.evals_result_)
    elif isinstance(booster, dict):
        history = deepcopy(booster)
    else:
        raise TypeError('booster must be dict or LGBMModel.')

    if not len(history):
        raise ValueError('eval results cannot be empty.')

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    if dataset_names is None:
        pending_names = iter(history.keys())
    elif isinstance(dataset_names, (list, tuple, set)) and dataset_names:
        pending_names = iter(dataset_names)
    else:
        raise ValueError('dataset_names should be iterable and cannot be empty')

    # The first dataset decides which metric is plotted and the x range.
    first_name = next(pending_names)
    first_metrics = history[first_name]
    metric_count = len(first_metrics)
    if metric is None:
        if metric_count > 1:
            warnings.warn('more than one metric available, picking one to plot.',
                          stacklevel=2)
        metric, series = first_metrics.popitem()
    elif metric in first_metrics:
        series = first_metrics[metric]
    else:
        raise KeyError('No given metric in eval results.')

    num_iteration = len(series)
    highest = max(series)
    lowest = min(series)
    iterations = range_(num_iteration)
    ax.plot(iterations, series, label=first_name)

    # Remaining datasets share the x values and widen the tracked y extent.
    for other_name in pending_names:
        series = history[other_name][metric]
        highest = max(max(series), highest)
        lowest = min(min(series), lowest)
        ax.plot(iterations, series, label=other_name)

    ax.legend(loc='best')

    if xlim is None:
        xlim = (0, num_iteration)
    else:
        _check_not_tuple_of_2_elements(xlim, 'xlim')
    ax.set_xlim(xlim)

    if ylim is None:
        # Pad the observed range by 20% on both sides.
        spread = highest - lowest
        ylim = (lowest - spread * 0.2, highest + spread * 0.2)
    else:
        _check_not_tuple_of_2_elements(ylim, 'ylim')
    ax.set_ylim(ylim)

    if ylabel == 'auto':
        ylabel = metric

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.grid(grid)
    return ax


def _to_graphviz(tree_info, show_info, feature_names, precision=None, **kwargs):
    """Build a graphviz ``Digraph`` from the dumped structure of one tree.

    See:
      - https://graphviz.readthedocs.io/en/stable/api.html#digraph

    Parameters
    ----------
    tree_info : dict
        Tree description; its 'tree_structure' entry is walked recursively.
    show_info : list of strings
        Extra fields to include in node labels.
    feature_names : list of strings or None
        If given, split nodes show the feature name instead of its index.
    precision : int or None, optional (default=None)
        Precision used when formatting floating point values.
    **kwargs
        Passed to the ``Digraph`` constructor.

    Returns
    -------
    graph : graphviz.Digraph
        The populated digraph.
    """
    if not GRAPHVIZ_INSTALLED:
        raise ImportError('You must install graphviz to plot tree.')
    from graphviz import Digraph

    # Edge captions for the (left, right) children, keyed by decision type.
    edge_captions = {'<=': ('<=', '>'), '==': ('is', "isn't")}
    float_fields = {'split_gain', 'internal_value'}

    def add(root, parent=None, decision=None):
        """Recursively add node or edge."""
        if 'split_index' in root:  # non-leaf
            name = 'split{0}'.format(root['split_index'])
            if feature_names is None:
                lines = ['split_feature_index: {0}'.format(root['split_feature'])]
            else:
                lines = ['split_feature_name: {0}'.format(feature_names[root['split_feature']])]
            lines.append('threshold: {0}'.format(_float2str(root['threshold'], precision)))
            for info in show_info:
                if info in float_fields:
                    lines.append('{0}: {1}'.format(info, _float2str(root[info], precision)))
                elif info == 'internal_count':
                    lines.append('{0}: {1}'.format(info, root[info]))
            # A literal backslash-n is the line separator inside graphviz labels.
            graph.node(name, label=r'\n'.join(lines))
            captions = edge_captions.get(root['decision_type'])
            if captions is None:
                raise ValueError('Invalid decision type in tree model.')
            left_caption, right_caption = captions
            add(root['left_child'], name, left_caption)
            add(root['right_child'], name, right_caption)
        else:  # leaf
            name = 'leaf{0}'.format(root['leaf_index'])
            lines = ['leaf_index: {0}'.format(root['leaf_index']),
                     'leaf_value: {0}'.format(_float2str(root['leaf_value'], precision))]
            if 'leaf_count' in show_info:
                lines.append('leaf_count: {0}'.format(root['leaf_count']))
            graph.node(name, label=r'\n'.join(lines))
        # The edge to the parent is emitted after the node's whole subtree.
        if parent is not None:
            graph.edge(parent, name, decision)

    graph = Digraph(**kwargs)
    add(tree_info['tree_structure'])
    return graph


def create_tree_digraph(booster, tree_index=0, show_info=None, precision=None,
                        old_name=None, old_comment=None, old_filename=None, old_directory=None,
                        old_format=None, old_engine=None, old_encoding=None, old_graph_attr=None,
                        old_node_attr=None, old_edge_attr=None, old_body=None, old_strict=False,
                        **kwargs):
    """Create a digraph representation of specified tree.

    Note
    ----
    For more information please visit
    https://graphviz.readthedocs.io/en/stable/api.html#digraph.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be converted.
    tree_index : int, optional (default=0)
        The index of a target tree to convert.
    show_info : list of strings or None, optional (default=None)
        What information should be shown in nodes.
        Possible values of list items: 'split_gain', 'internal_value',
        'internal_count', 'leaf_count'.
    precision : int or None, optional (default=None)
        Used to restrict the display of floating point values to a certain precision.
    old_name, old_comment, old_filename, old_directory, old_format, old_engine, \
old_encoding, old_graph_attr, old_node_attr, old_edge_attr, old_body : optional (default=None)
        Deprecated. When not None, a deprecation warning is issued and the value
        is forwarded to ``Digraph`` under the name without the ``old_`` prefix,
        unless that name is already present in ``**kwargs``.
    old_strict : bool, optional (default=False)
        Deprecated. When true, a deprecation warning is issued and
        ``strict=True`` is forwarded to ``Digraph`` unless ``strict`` is
        already present in ``**kwargs``.
    **kwargs
        Other parameters passed to ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph
        for the full list of supported parameters.

    Returns
    -------
    graph : graphviz.Digraph
        The digraph representation of specified tree.

    Raises
    ------
    TypeError
        If ``booster`` is neither Booster nor LGBMModel.
    IndexError
        If ``tree_index`` is out of range.
    """
    if isinstance(booster, LGBMModel):
        booster = booster.booster_
    elif not isinstance(booster, Booster):
        raise TypeError('booster must be Booster or LGBMModel.')

    # Explicit (name, value) pairs instead of ``locals()`` introspection.
    deprecated_params = (
        ('old_name', old_name), ('old_comment', old_comment),
        ('old_filename', old_filename), ('old_directory', old_directory),
        ('old_format', old_format), ('old_engine', old_engine),
        ('old_encoding', old_encoding), ('old_graph_attr', old_graph_attr),
        ('old_node_attr', old_node_attr), ('old_edge_attr', old_edge_attr),
        ('old_body', old_body),
    )
    for param_name, param in deprecated_params:
        if param is not None:
            new_name = param_name[4:]  # strip the 'old_' prefix
            warnings.warn('{0} parameter is deprecated and will be removed in 2.4 version.\n'
                          'Please use **kwargs to pass {1} parameter.'.format(param_name, new_name),
                          LGBMDeprecationWarning)
            # A value passed explicitly through **kwargs takes precedence.
            kwargs.setdefault(new_name, param)

    # The previous ``locals().get('strict')`` lookup never matched any local
    # name, so ``old_strict`` was silently ignored; read the parameter itself.
    if old_strict:
        warnings.warn('old_strict parameter is deprecated and will be removed in 2.4 version.\n'
                      'Please use **kwargs to pass strict parameter.',
                      LGBMDeprecationWarning)
        kwargs.setdefault('strict', True)

    model = booster.dump_model()
    tree_infos = model['tree_info']
    feature_names = model['feature_names'] if 'feature_names' in model else None

    if tree_index < len(tree_infos):
        tree_info = tree_infos[tree_index]
    else:
        raise IndexError('tree_index is out of range.')

    if show_info is None:
        show_info = []

    graph = _to_graphviz(tree_info, show_info, feature_names, precision, **kwargs)
    return graph


def plot_tree(booster, ax=None, tree_index=0, figsize=None,
              old_graph_attr=None, old_node_attr=None, old_edge_attr=None,
              show_info=None, precision=None, **kwargs):
    """Draw a single tree on matplotlib axes.

    The tree is converted with ``create_tree_digraph()``, rendered to PNG
    and shown as an image with the axis turned off.

    Note
    ----
    It is preferable to use ``create_tree_digraph()`` because of its lossless
    quality and returned objects can be also rendered and displayed directly
    inside a Jupyter notebook.

    Parameters
    ----------
    booster : Booster or LGBMModel
        Booster or LGBMModel instance to be plotted.
    ax : matplotlib.axes.Axes or None, optional (default=None)
        Target axes instance.
        If None, new figure and axes will be created.
    tree_index : int, optional (default=0)
        The index of a target tree to plot.
    figsize : tuple of 2 elements or None, optional (default=None)
        Figure size. Only used when ``ax`` is None.
    old_graph_attr, old_node_attr, old_edge_attr : optional (default=None)
        Deprecated. When not None, a deprecation warning is issued and the
        value is forwarded under the name without the ``old_`` prefix,
        unless that name is already present in ``**kwargs``.
    show_info : list of strings or None, optional (default=None)
        What information should be shown in nodes.
        Possible values of list items: 'split_gain', 'internal_value',
        'internal_count', 'leaf_count'.
    precision : int or None, optional (default=None)
        Used to restrict the display of floating point values to a certain precision.
    **kwargs
        Other parameters passed to ``Digraph`` constructor.
        Check https://graphviz.readthedocs.io/en/stable/api.html#digraph
        for the full list of supported parameters.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The plot with single tree.
    """
    if not MATPLOTLIB_INSTALLED:
        raise ImportError('You must install matplotlib to plot tree.')
    import matplotlib.pyplot as plt
    import matplotlib.image as image

    legacy_attrs = (
        ('old_graph_attr', old_graph_attr),
        ('old_node_attr', old_node_attr),
        ('old_edge_attr', old_edge_attr),
    )
    for legacy_name, legacy_value in legacy_attrs:
        if legacy_value is None:
            continue
        current_name = legacy_name[4:]  # strip the 'old_' prefix
        warnings.warn('{0} parameter is deprecated and will be removed in 2.4 version.\n'
                      'Please use **kwargs to pass {1} parameter.'.format(legacy_name, current_name),
                      LGBMDeprecationWarning)
        # A value passed explicitly through **kwargs takes precedence.
        if current_name not in kwargs:
            kwargs[current_name] = legacy_value

    if ax is None:
        if figsize is not None:
            _check_not_tuple_of_2_elements(figsize, 'figsize')
        _, ax = plt.subplots(1, 1, figsize=figsize)

    graph = create_tree_digraph(booster=booster, tree_index=tree_index,
                                show_info=show_info, precision=precision, **kwargs)

    # Render to PNG in memory and hand the buffer to matplotlib.
    png_buffer = BytesIO(graph.pipe(format='png'))
    img = image.imread(png_buffer)

    ax.imshow(img)
    ax.axis('off')
    return ax