机器学习---决策树和随机森林代码
2023-12-20 09:46:51
1、决策树代码
1.object ClassificationDecisionTree {
2.
3. def main(args: Array[String]): Unit = {
4. val conf = new SparkConf()
5. conf.setAppName("analysItem")
6. conf.setMaster("local[3]")
7. val sc = new SparkContext(conf)
8. val data = MLUtils.loadLibSVMFile(sc, "汽车数据样本.txt")
9. // Split the data into training and test sets (30% held out for testing)
10. val splits = data.randomSplit(Array(0.7, 0.3))
11. val (trainingData, testData) = (splits(0), splits(1))
12. //指明类别
13. val numClasses=2
14. //指定离散变量,未指明的都当作连续变量处理
15. //1,2,3,4维度进来就变成了0,1,2,3
16. //这里天气维度有3类,但是要指明4,这里是个坑,后面以此类推
17. val categoricalFeaturesInfo=Map[Int,Int](0->4,1->4,2->3,3->3)
18. //设定评判标准 "gini"/"entropy"
19. val impurity="entropy"
20. //树的最大深度,太深运算量大也没有必要 剪枝 防止模型的过拟合!!!
21. val maxDepth=3
22. //设置离散化程度,连续数据需要离散化,分成32个区间,默认其实就是32,分割的区间保证数量差不多 这个参数也可以进行剪枝
23. val maxBins=32
24. //生成模型
25. val model =DecisionTree.trainClassifier(trainingData,numClasses,categoricalFeaturesInfo,impurity,maxDepth,maxBins)
26. //测试
27. val labelAndPreds = testData.map { point =>
28. val prediction = model.predict(point.features)
29. (point.label, prediction)
30. }
31. val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
32. println("Test Error = " + testErr)
33. println("Learned classification tree model:\n" + model.toDebugString)
34.
35. }
36.}
2、随机森林代码
1.object ClassificationRandomForest {
2. def main(args: Array[String]): Unit = {
3. val conf = new SparkConf()
4. conf.setAppName("analysItem")
5. conf.setMaster("local[3]")
6. val sc = new SparkContext(conf)
7. //读取数据
8. val data = MLUtils.loadLibSVMFile(sc,"汽车数据样本.txt")
9. //将样本按7:3的比例分成
10. val splits = data.randomSplit(Array(0.7, 0.3))
11. val (trainingData, testData) = (splits(0), splits(1))
12. //分类数
13. val numClasses = 2
14. // categoricalFeaturesInfo 为空,意味着所有的特征为连续型变量
15. val categoricalFeaturesInfo =Map[Int, Int](0->4,1->4,2->3,3->3)
16. //树的个数
17. val numTrees = 3
18. //特征子集采样策略,auto 表示算法自主选取
19. //"auto"根据特征数量在4个中进行选择
20. // 1,all 全部特征 2,sqrt 把特征数量开根号后随机选择的 3,log2 取对数个 4,onethird 三分之一
21. val featureSubsetStrategy = "auto"
22. //纯度计算 "gini"/"entropy"
23. val impurity = "entropy"
24. //树的最大层次
25. val maxDepth = 3
26. //特征最大装箱数,即连续数据离散化的区间
27. val maxBins = 32
28. //训练随机森林分类器,trainClassifier 返回的是 RandomForestModel 对象
29. val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
30. numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
31. //打印模型
32. println(model.toDebugString)
33. //保存模型
34. //model.save(sc,"汽车保险")
35. //在测试集上进行测试
36. val count = testData.map { point =>
37. val prediction = model.predict(point.features)
38. // Math.abs(prediction-point.label)
39. (prediction,point.label)
40. }.filter(r => r._1 != r._2).count()
41. println("Test Error = " + count.toDouble/testData.count().toDouble)
42. println()
43. }
44.}
文章来源:https://blog.csdn.net/yaya_jn/article/details/135099044
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。 如若内容造成侵权/违法违规/事实不符,请联系我的编程经验分享网邮箱:veading@qq.com进行投诉反馈,一经查实,立即删除!