机器学习中Entropy怎么理解和计算?

文章导读
Previous Quiz Next 熵(Entropy)是一个源于热力学的概念,后来被应用于多个领域,包括信息论、统计学和机器学习。在机器学习中,熵被用作衡量数据集不纯度或随机性的指标。具体来说,熵在决策树算法中用于决定如何分割数据,以创建更均匀的子集。本文将讨论机器学习
📋 目录
  1. 示例 - Entropy 实现
A A

机器学习 - 熵



Previous
Quiz
Next

熵(Entropy)是一个源于热力学的概念,后来被应用于多个领域,包括信息论、统计学和机器学习。在机器学习中,熵被用作衡量数据集不纯度或随机性的指标。具体来说,熵在决策树算法中用于决定如何分割数据,以创建更均匀的子集。本文将讨论机器学习中的熵、其特性以及在 Python 中的实现。

熵被定义为系统中无序或随机性的度量。在决策树上下文中,熵被用作衡量节点不纯度的指标。如果节点中的所有样本都属于同一类,则该节点被视为纯节点。相反,如果节点包含来自多个类的样本,则该节点是不纯的。

要计算熵,首先需要定义数据集中每个类的概率。设 p(i) 为样本属于类 i 的概率。如果有 k 个类,则系统的总熵,用 H(S) 表示,计算公式如下 −

$$H\left ( S \right )=-sum\left ( p\left ( i \right )\ast log_{2}\left ( p\left ( i \right ) \right ) \right )$$

其中求和针对所有 k 个类进行。该方程称为 Shannon 熵。

例如,假设有一个包含 100 个样本的数据集,其中 60 个属于类 A,40 个属于类 B。那么类 A 的概率为 0.6,类 B 的概率为 0.4。该数据集的熵为 −

$$H\left ( S \right )=-(0.6\times log_{2}(0.6)+ 0.4\times log_{2}(0.4)) = 0.971$$

如果数据集中的所有样本都属于同一类,则熵为 0,表示纯节点。另一方面,如果样本在所有类中均匀分布,则熵较高,表示不纯节点。

在决策树算法中,熵用于确定每个节点的最佳分割。目标是创建导致最均匀子集的分割。这通过计算每个可能分割的熵并选择总熵最低的分割来实现。

例如,假设有一个具有两个特征 X1 和 X2 的数据集,目标是预测类标签 Y。我们首先计算整个数据集的熵 H(S)。接下来,计算基于每个特征的每个可能分割的熵。例如,可以基于 X1 的值或 X2 的值分割数据。每个分割的熵计算如下 −

$$H\left ( X_{1} \right )=p_{1}\times H\left ( S_{1} \right )+p_{2}\times H\left ( S_{2} \right )H\left ( X_{2} \right )=p_{3}\times H\left ( S_{3} \right )+p_{4}\times H\left ( S_{4} \right )$$

其中 p1、p2、p3 和 p4 是每个子集的概率;H(S1)、H(S2)、H(S3) 和 H(S4) 是每个子集的熵。

然后选择导致总熵最低的分割,该总熵由下式给出 −

$$H_{split}=H\left ( X_{1} \right )\, if\, H\left ( X_{1} \right )\leq H\left ( X_{2} \right );\: else\: H\left ( X_{2} \right )$$

此分割随后用于创建决策树的子节点,并递归重复该过程,直到所有节点均为纯节点或满足停止准则。

示例 - Entropy 实现

让我们通过一个示例来理解它如何在 Python 中实现。这里我们将使用 "iris" 数据集 −

from sklearn.datasets import load_iris
import numpy as np

# 加载 iris 数据集
iris = load_iris()

# 提取特征和目标
X = iris.data
y = iris.target

# 定义一个计算 entropy 的函数
def entropy(y):
   n = len(y)
   _, counts = np.unique(y, return_counts=True)
   probs = counts / n
   return -np.sum(probs * np.log2(probs))

# 计算目标变量的 entropy
target_entropy = entropy(y)
print(f"Target entropy: {target_entropy:.3f}")

上述代码加载 iris 数据集,提取特征和目标,并定义了一个计算 entropy 的函数。entropy() 函数接收目标值的向量,并返回集合的 entropy。

该函数首先计算集合中的示例数量和每个 class 的计数。然后计算每个 class 的比例,并使用这些比例根据 entropy 公式计算集合的 entropy。最后,代码计算 iris 数据集中目标变量的 entropy 并将其打印到控制台。

输出

执行此代码时,将产生以下输出 −

Target entropy: 1.585