您的位置: 首页> 业界 > 正文

【机器学习】集成学习代码练习(随机森林、GBDT、XGBoost、LightGBM等)_全球热讯

2022-12-24 09:46:59 来源:

本文是中国大学慕课《机器学习》的“集成学习”章节的课后代码。

课程地址:

https://www.icourse163.org/course/WZU-1464096179


(资料图片仅供参考)

课程完整代码:

https://github.com/fengdu78/WZU-machine-learning-course

代码修改并注释:黄海广,haiguang2000@wzu.edu.cn

importwarningswarnings.filterwarnings("ignore")importpandasaspdfromsklearn.model_selectionimporttrain_test_split

生成数据

生成12000行的数据,训练集和测试集按照3:1划分

fromsklearn.datasetsimportmake_hastie_10_2data,target=make_hastie_10_2()

X_train,X_test,y_train,y_test=train_test_split(data,target,random_state=123)X_train.shape,X_test.shape

((9000, 10), (3000, 10))

模型对比

对比六大模型,都使用默认参数

fromsklearn.linear_modelimportLogisticRegressionfromsklearn.ensembleimportRandomForestClassifierfromsklearn.ensembleimportAdaBoostClassifierfromsklearn.ensembleimportGradientBoostingClassifierfromxgboostimportXGBClassifierfromlightgbmimportLGBMClassifierfromsklearn.model_selectionimportcross_val_scoreimporttimeclf1=LogisticRegression()clf2=RandomForestClassifier()clf3=AdaBoostClassifier()clf4=GradientBoostingClassifier()clf5=XGBClassifier()clf6=LGBMClassifier()forclf,labelinzip([clf1,clf2,clf3,clf4,clf5,clf6],["LogisticRegression","RandomForest","AdaBoost","GBDT","XGBoost","LightGBM"]):start=time.time()scores=cross_val_score(clf,X_train,y_train,scoring="accuracy",cv=5)end=time.time()running_time=end-startprint("Accuracy:%0.8f (+/-%0.2f),耗时%0.2f秒。模型名称[%s]"%(scores.mean(),scores.std(),running_time,label))

Accuracy: 0.47488889 (+/- 0.00),耗时0.04秒。模型名称[Logistic Regression]Accuracy: 0.88966667 (+/- 0.01),耗时16.34秒。模型名称[Random Forest]Accuracy: 0.88311111 (+/- 0.00),耗时3.39秒。模型名称[AdaBoost]Accuracy: 0.91388889 (+/- 0.01),耗时13.14秒。模型名称[GBDT]Accuracy: 0.92977778 (+/- 0.00),耗时3.60秒。模型名称[XGBoost]Accuracy: 0.93188889 (+/- 0.01),耗时0.58秒。模型名称[LightGBM]

对比了六大模型,可以看出,逻辑回归速度最快,但准确率最低。而LightGBM,速度快,而且准确率最高,所以,现在处理结构化数据的时候,大部分都是用LightGBM算法。

XGBoost的使用 1.原生XGBoost的使用

importxgboostasxgb#记录程序运行时间importtimestart_time=time.time()#xgb矩阵赋值xgb_train=xgb.DMatrix(X_train,y_train)xgb_test=xgb.DMatrix(X_test,label=y_test)##参数params={"booster":"gbtree",#"silent":1,#设置成1则没有运行信息输出,最好是设置为0.#"nthread":7,#cpu线程数默认最大"eta":0.007,#如同学习率"min_child_weight":3,#这个参数默认是1,是每个叶子里面h的和至少是多少,对正负样本不均衡时的0-1分类而言#,假设 h 在0.01 附近,min_child_weight 为 1 意味着叶子节点中最少需要包含 100个样本。#这个参数非常影响结果,控制叶子节点中二阶导的和的最小值,该参数值越小,越容易 overfitting。"max_depth":6,#构建树的深度,越大越容易过拟合"gamma":0.1,#树的叶子节点上作进一步分区所需的最小损失减少,越大越保守,一般0.1、0.2这样子。"subsample":0.7,#随机采样训练样本"colsample_bytree":0.7,#生成树时进行的列采样"lambda":2,#控制模型复杂度的权重值的L2正则化项参数,参数越大,模型越不容易过拟合。#"alpha":0,#L1正则项参数#"scale_pos_weight":1, #如果取值大于0的话,在类别样本不平衡的情况下有助于快速收敛。#"objective":"multi:softmax",#多分类的问题#"num_class":10,#类别数,多分类与multisoftmax并用"seed":1000,#随机种子#"eval_metric":"auc"}plst=list(params.items())num_rounds=500#迭代次数watchlist=[(xgb_train,"train"),(xgb_test,"val")]

#训练模型并保存#early_stopping_rounds当设置的迭代次数较大时,early_stopping_rounds可在一定的迭代次数内准确率没有提升就停止训练model=xgb.train(plst,xgb_train,num_rounds,watchlist,early_stopping_rounds=100,)#model.save_model("./model/xgb.model")#用于存储训练出的模型print("bestbest_ntree_limit",model.best_ntree_limit)y_pred=model.predict(xgb_test,ntree_limit=model.best_ntree_limit)print("error=%f"%(sum(1foriinrange(len(y_pred))ifint(y_pred[i]>0.5)!=y_test[i])/float(len(y_pred))))#输出运行时长cost_time=time.time()-start_timeprint("xgboostsuccess!","\n","costtime:",cost_time,"(s)......")

[0]train-rmse:1.11000val-rmse:1.10422[1]train-rmse:1.10734val-rmse:1.10182[2]train-rmse:1.10465val-rmse:1.09932[3]train-rmse:1.10207val-rmse:1.09694

……

[497]train-rmse:0.62135val-rmse:0.68680[498]train-rmse:0.62096val-rmse:0.68650[499]train-rmse:0.62056val-rmse:0.68624best best_ntree_limit 500error=0.826667xgboost success!  cost time: 3.5742645263671875 (s)......

2.使用scikit-learn接口

会改变的函数名是:

eta -> learning_rate

lambda -> reg_lambda

alpha -> reg_alpha

fromsklearn.model_selectionimporttrain_test_splitfromsklearnimportmetricsfromxgboostimportXGBClassifierclf=XGBClassifier(# silent=0, #设置成1则没有运行信息输出,最好是设置为0.是否在运行升级时打印消息。#nthread=4,#cpu线程数默认最大learning_rate=0.3,#如同学习率min_child_weight=1,#这个参数默认是1,是每个叶子里面h的和至少是多少,对正负样本不均衡时的0-1分类而言#,假设 h 在0.01 附近,min_child_weight 为 1 意味着叶子节点中最少需要包含 100个样本。#这个参数非常影响结果,控制叶子节点中二阶导的和的最小值,该参数值越小,越容易 overfitting。max_depth=6,#构建树的深度,越大越容易过拟合gamma=0,#树的叶子节点上作进一步分区所需的最小损失减少,越大越保守,一般0.1、0.2这样子。subsample=1,#随机采样训练样本训练实例的子采样比max_delta_step=0,#最大增量步长,我们允许每个树的权重估计。colsample_bytree=1,#生成树时进行的列采样reg_lambda=1,#控制模型复杂度的权重值的L2正则化项参数,参数越大,模型越不容易过拟合。#reg_alpha=0,#L1正则项参数#scale_pos_weight=1, #如果取值大于0的话,在类别样本不平衡的情况下有助于快速收敛。平衡正负权重#objective="multi:softmax",#多分类的问题指定学习任务和相应的学习目标#num_class=10,#类别数,多分类与multisoftmax并用n_estimators=100,#树的个数seed=1000#随机种子#eval_metric="auc")clf.fit(X_train,y_train)y_true,y_pred=y_test,clf.predict(X_test)print("Accuracy:%.4g"%metrics.accuracy_score(y_true,y_pred))

Accuracy : 0.936

LIghtGBM的使用 1.原生接口

importlightgbmaslgbfromsklearn.metricsimportmean_squared_error#加载你的数据#print("Loaddata...")#df_train=pd.read_csv("../regression/regression.train",header=None,sep="\t")#df_test=pd.read_csv("../regression/regression.test",header=None,sep="\t")##y_train=df_train[0].values#y_test=df_test[0].values#X_train=df_train.drop(0,axis=1).values#X_test=df_test.drop(0,axis=1).values#创建成lgb特征的数据集格式lgb_train=lgb.Dataset(X_train,y_train)#将数据保存到LightGBM二进制文件将使加载更快lgb_eval=lgb.Dataset(X_test,y_test,reference=lgb_train)#创建验证数据#将参数写成字典下形式params={"task":"train","boosting_type":"gbdt",#设置提升类型"objective":"regression",#目标函数"metric":{"l2","auc"},#评估函数"num_leaves":31,#叶子节点数"learning_rate":0.05,#学习速率"feature_fraction":0.9,#建树的特征选择比例"bagging_fraction":0.8,#建树的样本采样比例"bagging_freq":5,#k意味着每k次迭代执行bagging"verbose":1#<0显示致命的,=0显示错误(警告),>0显示信息}print("Starttraining...")#训练cvandtraingbm=lgb.train(params,lgb_train,num_boost_round=500,valid_sets=lgb_eval,early_stopping_rounds=5)#训练数据需要参数列表和数据集print("Savemodel...")gbm.save_model("model.txt")#训练后保存模型到文件print("Startpredicting...")#预测数据集y_pred=gbm.predict(X_test,num_iteration=gbm.best_iteration)#如果在训练期间启用了早期停止,可以通过best_iteration方式从最佳迭代中获得预测#评估模型print("error=%f"%(sum(1foriinrange(len(y_pred))ifint(y_pred[i]>0.5)!=y_test[i])/float(len(y_pred))))

Start training...[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000448 seconds.You can set `force_col_wise=true` to remove the overhead.[LightGBM] [Info] Total Bins 2550[LightGBM] [Info] Number of data points in the train set: 9000, number of used features: 10[LightGBM] [Info] Start training from score 0.012000[1]valid_0"s auc: 0.814399valid_0"s l2: 0.965563Training until validation scores don"t improve for 5 rounds[2]valid_0"s auc: 0.84729valid_0"s l2: 0.934647[3]valid_0"s auc: 0.872805valid_0"s l2: 0.905265[4]valid_0"s auc: 0.884117valid_0"s l2: 0.877875[5]valid_0"s auc: 0.895115valid_0"s l2: 0.852189

……

[191]valid_0"s auc: 0.982783valid_0"s l2: 0.319851[192]valid_0"s auc: 0.982751valid_0"s l2: 0.319971[193]valid_0"s auc: 0.982685valid_0"s l2: 0.320043Early stopping, best iteration is:[188]valid_0"s auc: 0.982794valid_0"s l2: 0.319746Save model...Start predicting...error=0.664000

2.scikit-learn接口

fromsklearnimportmetricsfromlightgbmimportLGBMClassifierclf=LGBMClassifier(boosting_type="gbdt",#提升树的类型gbdt,dart,goss,rfnum_leaves=31,#树的最大叶子数,对比xgboost一般为2^(max_depth)max_depth=-1,#最大树的深度learning_rate=0.1,#学习率n_estimators=100,#拟合的树的棵树,相当于训练轮数subsample_for_bin=200000,objective=None,class_weight=None,min_split_gain=0.0,#最小分割增益min_child_weight=0.001,#分支结点的最小权重min_child_samples=20,subsample=1.0,#训练样本采样率行subsample_freq=0,#子样本频率colsample_bytree=1.0,#训练特征采样率列reg_alpha=0.0,#L1正则化系数reg_lambda=0.0,#L2正则化系数random_state=None,n_jobs=-1,silent=True,)clf.fit(X_train,y_train,eval_metric="auc")#设置验证集合verbose=False不打印过程clf.fit(X_train,y_train)y_true,y_pred=y_test,clf.predict(X_test)print("Accuracy:%.4g"%metrics.accuracy_score(y_true,y_pred))

Accuracy : 0.927

参考

1.https://xgboost.readthedocs.io/

2.https://lightgbm.readthedocs.io/

3.https://blog.csdn.net/q383700092/article/details/53763328?locationNum=9&fps=1

往期精彩回顾适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码

关键词:

资讯
业界
企业
骑闻
产品
证券公司理财产品安全吗?开证券账户有什么风险?
证券公司理财产品安全吗1 从平台正规性来看,证券公司和银行一样都是正规的理财平台,它们的产品都是经过监管部门审批的,2 从理财产品本身
2022-06-23
  海军驻西沙某场站基础设施建设提档升级——  军嫂上岛有了温馨的家  解放军报特约记者 高宏伟 通讯员 沈宏权 傅金泉  春节前
2022-01-24
  1月23日,新疆维吾尔自治区第十三届人民代表大会第五次会议在乌鲁木齐开幕。开幕当天的政府工作报告中指出,新疆全面推进反恐维稳法治
2022-01-24
青海海西州德令哈市附近发生5.8级地震
  中新网1月23日电 据国家地震台网官方微博消息,中国地震台网正式测定:1月23日10时21分在青海海西州德令哈市(北纬38 44度,东经97 37
2022-01-24
迎春节旅游旺季 三亚启动“奇趣新春·游品三亚”系列活动
  中新网三亚1月23日电 (记者 王晓斌)1月22日晚,“奇趣新春·游品三亚”2022新春文化旅游季活动正式启动,即日起至春节假期结束,三亚
2022-01-24
  虎年春节将至,作为广府文化传统习俗的广州迎春花市还办不办、怎么办?1月23日,记者从广州市花市办获悉,经征求各方意见,广州市政府
2022-01-24
外汇局:11月我国国际收支货物和服务贸易进出口规模同比降3%|短讯
国家外汇管理局统计数据显示,2022年11月,我国国际收支货物和服务贸易进出口规模39804亿元,同比下降3%。按美元计值,2022年11月,我国国际收
2022-12-30
2022广州车展:Huracán Tecnica亮相|天天热点评
[汽车之家新车首发]2022广州车展正式开幕,在本届车展上,兰博基尼HuracánTecnica车型正式亮相。新车凝聚了兰博基尼在设计及工程学方面的专业
2022-12-30
报道:金发拉比12月30日快速反弹
以下是金发拉比在北京时间12月30日09:41分盘口异动快照:12月30日,金发拉比盘中快速反弹,5分钟内涨幅超过2%,截至9点41分,报8 98元,成交75
2022-12-30
河北省省区劳动争议律师费用一般怎么计算
1、按件收费收取(1)无财产争议:6000元-20000元之间;(2)法律文书:600元-2000元之间;(3)律师见证:2000元-10000元之间;(4)代办公证
2022-12-30
环球要闻:中央广播电视总台发布2022年度十大国内、十大国际军事新闻
央视军事中央广播电视总台发布2022年度十大国内军事新闻2022年度十大国际军事新闻
2022-12-29
敏芯股份(688286.SH):高管张辰良完成减持15.30万股_每日焦点
格隆汇12月29日丨敏芯股份公布,2022年12月29日,公司收到高级管理人员张辰良出具的《关于股份减持结果的告知函》。截至公告披露日,公司高级
2022-12-29
  中新网海口1月23日电(王子谦 符宇群)海南省高级人民法院院长陈凤超23日说,2021年海南法院为自贸港建设提供坚强司法保障,全年有效管
2022-01-24
  新华社武汉1月23日电(记者王贤)随着春节假期临近,从广州、深圳等地返回湖北的旅客较多。为此,23日,武汉站、汉口站、襄阳东站、十堰
2022-01-24
  1月22日0—24时,广东省新增本土确诊病例3例和本土无症状感染者1例,均为珠海报告。23日,珠海市疫情防控新闻发布会上,珠海市政府副秘
2022-01-24
青海海西州德令哈市发生3.7级地震
  据中国地震台网正式测定,1月23日11时58分在青海海西州德令哈市发生3 7级地震,震源深度9千米,震中位于北纬38 40度,东经97 35度。
2022-01-24
  北京2022年冬奥会和冬残奥会颁奖花束已于近期完成交付。与传统的鲜切花不同,这些花束全部采用上海市非物质文化遗产“海派绒线编结技艺
2022-01-24
  疫情就是命令,防控就是责任。在抗击疫情的关键时刻,西安全员上下一盘棋,同舟共济、共克时艰。不论是党员干部或是社区志愿者,他们都
2022-01-24
  中新网宿迁1月23日电 (刘林 张华东)核酸检测是当下及时发现潜在感染者、阻断疫情传播的有效方法。23日,记者从宿迁市宿豫区警方获悉
2022-01-24
  记者从天津市人社局获悉,从明天(24日)起,天津2022年度第一期积分落户申报工作正式开始,这是新修订的《天津市居住证管理办法》《天津
2022-01-24
  中新社北京1月23日电 (记者 刘亮)记者23日从中国海关总署获悉,2021年,中国海关组织开展“国门绿盾”专项行动,在寄递、旅客携带物
2022-01-24
  记者从天津市疫情防控指挥部获悉,天津疫情第341—360例阳性感染者基本信息公布。  目前,这20例阳性感染者已转运至市定点医院做进一
2022-01-24
“最美基层民警”武文斌:案子破了最管用
  中新网吕梁1月23日电 题:“最美基层民警”武文斌:案子破了最管用  作者 高瑞峰  同事称他为“拼命三郎”。从警14年,武文斌破
2022-01-24
  据“西安发布”消息,截至2022年1月23日,雁塔区长延堡街道近14天内无新增本地病例和聚集性疫情。根据国务院联防联控机制关于分区分级
2022-01-24
  中新网西宁1月23日电 (记者 孙睿)据青海省地震台网测定,2022年1月23日10点21分(北京时间)在青海省海西州德令哈市(北纬38 44度,东经
2022-01-24
江西南昌:市民赏年画迎新年 书法家挥毫送春联
  (新春见闻)江西南昌:市民赏年画迎新年 书法家挥毫送春联  1月23日,“赏年画过大年”新年画作品联展江西南昌站活动在江西省文化馆
2022-01-24
  中新网成都1月23日电 (祝欢)成都市第十七届人民代表大会第六次会议23日在成都举行,成都市中级人民法院院长郭彦与成都市人民检察院检
2022-01-24
列车临时停车3分钟救旅客
  (新春见闻)列车临时停车3分钟救旅客  中新网广州1月23日电 (郭军 黄伟伟)“车长,车长,4号车厢有位旅客腹涨难忍,身体不舒服”…
2022-01-24
女子背负命案潜逃24年 因涉疫人员核查落网
  中新网湖州1月23日电(施紫楠 徐盛煜 赵学良)1998年7月,犯罪嫌疑人杜某因家庭琐事,用菜刀将自己的弟媳砍伤致死。案发后,她从老家河
2022-01-24
广东“00后”雄狮少年锤炼功夫迎新春
  (新春见闻)广东“00后”雄狮少年锤炼功夫迎新春  中新社广州1月23日电 题:广东“00后”雄狮少年锤炼功夫迎新春  作者 孙秋霞 
2022-01-24
08-18 宁夏文旅厅:推动建立旅游企业“首席质量官”和“标杆服务员”
宁夏文旅厅:推动建立旅游企业“首席质量官”和“标杆服务员”
宁夏回族自治区文化和旅游厅近日印发《自治区文化和旅游厅关于推动建立旅游企业首席质量官和标杆服务员制度的通知》(以下简称《通知》),提 [详细]
08-18 第七届中国非物质文化遗产博览会将于8月25日至29日在济南举行
第七届中国非物质文化遗产博览会将于8月25日至29日在济南举行
8月17日,文化和旅游部召开第七届中国非物质文化遗产博览会(以下简称博览会)新闻发布会。会上宣布,博览会将于8月25日至29日在山东省济南市 [详细]
01-24 西安浐灞回应“社区领导怒怼咨询群众”:涉事社区主任已停职
西安浐灞回应“社区领导怒怼咨询群众”:涉事社区主任已停职
  西安浐灞回应“一社区领导在市民咨询离市政策时发生争执”事件 涉事社区主任已停职  西部网讯(记者 刘望)日前,网络上流传一条视频 [详细]
01-24 陕西:截至23日12时 西安56.5万大中专学生已离校返家
陕西:截至23日12时 西安56.5万大中专学生已离校返家
  1月23日,陕西省举行第45场疫情防控工作发布会,发布会上陕西省教育厅相关负责人通报,陕西全省疫情有效控制后,大中专学校能不能放假 [详细]
01-24 河北魏县发布北京一阳性人员在魏县的主要轨迹
河北魏县发布北京一阳性人员在魏县的主要轨迹
  魏县疾病预防控制中心关于紧急寻找丰台区新冠肺炎阳性检测者同时间同空间人员的公告  2022年01月22日,接到邯郸市疾控中心转北京市疾 [详细]
01-24 陕西:滞留西安的外省研考生已于1月15日安全返乡
陕西:滞留西安的外省研考生已于1月15日安全返乡
  总台记者从陕西省第45场疫情防控工作新闻发布会上获悉,2022年全国研究生考试陕西全省报名16 8万人,其中应在西安市参考11 85万人,实 [详细]
01-24 宁夏:“草根主播”把货卖 “线上赶集”年味浓
宁夏:“草根主播”把货卖 “线上赶集”年味浓
  (新春走基层)宁夏:“草根主播”把货卖 “线上赶集”年味浓  中新网宁夏红寺堡1月23日电 题:宁夏:“草根主播”把货卖 “线上赶 [详细]