朴素贝叶斯 Numpy实现高斯朴素贝叶斯
Numpy实现朴素贝叶斯
朴素贝叶斯
y = arg ? max ? c k P ( Y = c k ) ∏ j = 1 n P ( X j = x ( j ) Y = c k ) y=\arg \max _{c_{k}} P\left(Y=c_{k}\right) \prod_{j=1}^{n} P\left(X_{j}=x^{(j)} Y=c_{k}\right) y=argck?max?P(Y=ck?)j=1∏n?P(Xj?=x(j)Y=ck?)
后验概率最大等价于0-1损失函数时的期望风险最小化。
GaussianNB 高斯朴素贝叶斯
特征的可能性被假设为高斯
概率密度函数:
P
(
x
i
∣
y
k
)
=
1
2
π
σ
y
k
2
e
x
p
(
?
(
x
i
?
μ
y
k
)
2
2
σ
y
k
2
)
P(x_i | y_k)=\frac{1}{\sqrt{2\pi\sigma^2_{yk}}}exp(-\frac{(x_i-\mu_{yk})^2}{2\sigma^2_{yk}})
P(xi?∣yk?)=2πσyk2??1?exp(?2σyk2?(xi??μyk?)2?)
数学期望(mean): μ \mu μ
方差: σ 2 = ∑ ( X ? μ ) 2 N \sigma^2=\frac{\sum(X-\mu)^2}{N} σ2=N∑(X?μ)2?
代码实现
import numpy as np
from scipy.stats import norm
class GaussianNaiveBayes:
def fit(self, X, y):
# 获取类别标签
self.classes = np.unique(y)
# 计算每个类别的先验概率
self.class_probs = self._calculate_class_probs(y)
# 计算每个类别的特征均值和方差
self.mean, self.variance = self._calculate_statistics(X, y)
def predict(self, X):
# 对每个样本进行预测
predictions = [self._predict_instance(x) for x in X]
return np.array(predictions)
def _predict_instance(self, x):
# 计算每个类别的后验概率,并返回具有最大后验概率的类别
posteriors = []
for idx, c in enumerate(self.classes):
prior = np.log(self.class_probs[idx])
posterior = np.sum(np.log(norm.pdf(x, loc=self.mean[idx], scale=np.sqrt(self.variance[idx]))))
posterior += prior
posteriors.append(posterior)
return self.classes[np.argmax(posteriors)]
def _calculate_class_probs(self, y):
# 计算每个类别的先验概率
class_probs = [np.sum(y == c) / len(y) for c in self.classes]
return class_probs
def _calculate_statistics(self, X, y):
# 计算每个类别的特征均值和方差
mean = []
variance = []
for c in self.classes:
X_c = X[y == c]
mean_c = np.mean(X_c, axis=0)
variance_c = np.var(X_c, axis=0)
mean.append(mean_c)
variance.append(variance_c)
return np.array(mean), np.array(variance)
# 生成一些示例数据
np.random.seed(42)
X = np.random.rand(100, 2)
y = (X[:, 0] + X[:, 1] > 1).astype(int)
# 创建并训练高斯朴素贝叶斯分类器
nb_classifier = GaussianNaiveBayes()
nb_classifier.fit(X, y)
# 预测新样本
new_samples = np.array([[0.8, 0.2], [0.4, 0.6]])
predictions = nb_classifier.predict(new_samples)
print("预测结果:", predictions)
预测结果: [1 0]
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!