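政安晨: [Keras Machine Learning Examples Walkthrough] (53): Classification with TensorFlow Decision Forests
Contents
Introduction
Setup
Prepare the data
Define dataset metadata
Configure hyperparameters
Implement a training and evaluation procedure
Experiment 1: Decision Forests with raw features
Inspect the model
Experiment 2: Decision Forests with target encoding
Create model inputs
Implement a feature encoding with target encoding
Create a Gradient Boosted Trees model with a preprocessor
Train and evaluate the model
Experiment 3: Decision Forests with trained embeddings
Concluding remarks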
政安晨's homepage: 政安晨
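Part of the column: TensorFlow and Keras Machine Learning in Practice
I hope this blog is helpful to you. If you spot any shortcomings, corrections are welcome in the comments!
Goal of this article: classify structured data using TensorFlow Decision Forests.
Introduction
TensorFlow Decision Forests is a collection of state-of-the-art algorithms for Decision Forest models that are compatible with the Keras API. The models include Random Forests, Gradient Boosted Trees, and CART, and can be used for regression, classification, and ranking tasks.
This example uses a Gradient Boosted Trees model for binary classification of structured data, and covers the following scenarios:
1. Build a decision forests model by specifying the input feature usages.
2. Implement a custom binary target encoder as a Keras preprocessing layer that encodes the categorical features with respect to their co-occurrence with the target values, and then use the encoded features to build a decision forests model.
3. Encode the categorical features as embeddings, train these embeddings in a simple NN model, and then use the trained embeddings as inputs to build a decision forests model.
This example uses TensorFlow 2.7 or higher, as well as TensorFlow Decision Forests, which you can install with the following command: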
pip install -U tensorflow_decision_forests
Setup
import math
import urllib.request
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_decision_forests as tfdf
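Prepare the data
This example uses the United States Census Income Dataset provided by the UC Irvine Machine Learning Repository.
The dataset contains about 300K instances with 41 input features: 7 numerical features and 34 categorical features.
First, we load the data from the UCI Machine Learning Repository into a Pandas DataFrame.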
BASE_PATH = "https://kdd.ics.uci.edu/databases/census-income/census-income"
CSV_HEADER = [
    l.decode("utf-8").split(":")[0].replace(" ", "_")
    for l in urllib.request.urlopen(f"{BASE_PATH}.names")
    if not l.startswith(b"|")
][2:]
CSV_HEADER.append("income_level")

train_data = pd.read_csv(f"{BASE_PATH}.data.gz", header=None, names=CSV_HEADER,)
test_data = pd.read_csv(f"{BASE_PATH}.test.gz", header=None, names=CSV_HEADER,)
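To make the header-parsing comprehension above concrete, here is a minimal offline sketch. The sample lines are invented for illustration and only imitate the layout the comprehension assumes (comment lines starting with `|`, then `name: values` lines). The helper name `parse_header` is hypothetical as well.

```python
def parse_header(lines, skip=2):
    """Derive column names from `name: values` lines, ignoring `|` comments."""
    names = [
        line.decode("utf-8").split(":")[0].replace(" ", "_")
        for line in lines
        if not line.startswith(b"|")
    ]
    # Drop the first `skip` non-comment lines, mirroring the `[2:]` slice above.
    return names[skip:]


# Invented sample input; the real file is not reproduced here.
sample_lines = [
    b"| a comment line\n",
    b"- 50000, 50000+.\n",
    b"\n",
    b"age: continuous.\n",
    b"class of worker: Private, Self-employed.\n",
]
header = parse_header(sample_lines)
print(header)
```

Spaces in the attribute names become underscores, which is why the feature lists later in the article use names such as class_of_worker.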
Define dataset metadata
Here, we define the metadata of the dataset, which will be useful for encoding the input features according to their types.
# Target column name.
TARGET_COLUMN_NAME = "income_level"
# The labels of the target columns.
TARGET_LABELS = [" - 50000.", " 50000+."]
# Weight column name.
WEIGHT_COLUMN_NAME = "instance_weight"
# Numeric feature names.
NUMERIC_FEATURE_NAMES = [
    "age",
    "wage_per_hour",
    "capital_gains",
    "capital_losses",
    "dividends_from_stocks",
    "num_persons_worked_for_employer",
    "weeks_worked_in_year",
]
# Categorical features and their vocabulary lists.
CATEGORICAL_FEATURE_NAMES = [
    "class_of_worker",
    "detailed_industry_recode",
    "detailed_occupation_recode",
    "education",
    "enroll_in_edu_inst_last_wk",
    "marital_stat",
    "major_industry_code",
    "major_occupation_code",
    "race",
    "hispanic_origin",
    "sex",
    "member_of_a_labor_union",
    "reason_for_unemployment",
    "full_or_part_time_employment_stat",
    "tax_filer_stat",
    "region_of_previous_residence",
    "state_of_previous_residence",
    "detailed_household_and_family_stat",
    "detailed_household_summary_in_household",
    "migration_code-change_in_msa",
    "migration_code-change_in_reg",
    "migration_code-move_within_reg",
    "live_in_this_house_1_year_ago",
    "migration_prev_res_in_sunbelt",
    "family_members_under_18",
    "country_of_birth_father",
    "country_of_birth_mother",
    "country_of_birth_self",
    "citizenship",
    "own_business_or_self_employed",
    "fill_inc_questionnaire_for_veteran's_admin",
    "veterans_benefits",
    "year",
]
Now we perform basic data preparation.
def prepare_dataframe(dataframe):
    # Convert the target labels from string to integer.
    dataframe[TARGET_COLUMN_NAME] = dataframe[TARGET_COLUMN_NAME].map(
        TARGET_LABELS.index
    )
    # Cast the categorical features to string.
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        dataframe[feature_name] = dataframe[feature_name].astype(str)


prepare_dataframe(train_data)
prepare_dataframe(test_data)
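The label conversion above relies on `TARGET_LABELS.index`, which returns the position of a label string in the list. The following plain-Python sketch shows the same mapping without pandas. The `raw_labels` values are made up for illustration.

```python
TARGET_LABELS = [" - 50000.", " 50000+."]

# Hypothetical raw label column; note the leading space in each label string.
raw_labels = [" - 50000.", " 50000+.", " - 50000."]

# Each label becomes its position in TARGET_LABELS: 0 or 1.
encoded = [TARGET_LABELS.index(label) for label in raw_labels]
print(encoded)

# A label that differs in whitespace is not found: list.index raises ValueError.
try:
    TARGET_LABELS.index("50000+.")
    found = True
except ValueError:
    found = False
print(found)
```

The leading space in the label strings matters, so the entries of TARGET_LABELS have to match the raw CSV values exactly.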
Now let's show the shapes of the training and test dataframes, and display some instances.
print(f"Train data shape: {train_data.shape}")
print(f"Test data shape: {test_data.shape}")
print(train_data.head().T)
Train data shape: (199523, 42)
Test data shape: (99762, 42)
0 \
age 73
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education High school graduate
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Widowed
major_industry_code Not in universe or children
major_occupation_code Not in universe
race White
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Not in labor force
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Other Rel 18+ ever marr not in subfamily
detailed_household_summary_in_household Other relative of householder
instance_weight 1700.09
migration_code-change_in_msa ?
migration_code-change_in_reg ?
migration_code-move_within_reg ?
live_in_this_house_1_year_ago Not in universe under 1 year old
migration_prev_res_in_sunbelt ?
num_persons_worked_for_employer 0
family_members_under_18 Not in universe
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 2
weeks_worked_in_year 0
year 95
income_level 0
1 \
age 58
class_of_worker Self-employed-not incorporated
detailed_industry_recode 4
detailed_occupation_recode 34
education Some college but no degree
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Divorced
major_industry_code Construction
major_occupation_code Precision production craft & repair
race White
hispanic_origin All other
sex Male
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Children or Armed Forces
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Head of household
region_of_previous_residence South
state_of_previous_residence Arkansas
detailed_household_and_family_stat Householder
detailed_household_summary_in_household Householder
instance_weight 1053.55
migration_code-change_in_msa MSA to MSA
migration_code-change_in_reg Same county
migration_code-move_within_reg Same county
live_in_this_house_1_year_ago No
migration_prev_res_in_sunbelt Yes
num_persons_worked_for_employer 1
family_members_under_18 Not in universe
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 2
weeks_worked_in_year 52
year 94
income_level 0
2 \
age 18
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education 10th grade
wage_per_hour 0
enroll_in_edu_inst_last_wk High school
marital_stat Never married
major_industry_code Not in universe or children
major_occupation_code Not in universe
race Asian or Pacific Islander
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Not in labor force
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Child 18+ never marr Not in a subfamily
detailed_household_summary_in_household Child 18 or older
instance_weight 991.95
migration_code-change_in_msa ?
migration_code-change_in_reg ?
migration_code-move_within_reg ?
live_in_this_house_1_year_ago Not in universe under 1 year old
migration_prev_res_in_sunbelt ?
num_persons_worked_for_employer 0
family_members_under_18 Not in universe
country_of_birth_father Vietnam
country_of_birth_mother Vietnam
country_of_birth_self Vietnam
citizenship Foreign born- Not a citizen of U S
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 2
weeks_worked_in_year 0
year 95
income_level 0
3 \
age 9
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education Children
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Never married
major_industry_code Not in universe or children
major_occupation_code Not in universe
race White
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Children or Armed Forces
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Child <18 never marr not in subfamily
detailed_household_summary_in_household Child under 18 never married
instance_weight 1758.14
migration_code-change_in_msa Nonmover
migration_code-change_in_reg Nonmover
migration_code-move_within_reg Nonmover
live_in_this_house_1_year_ago Yes
migration_prev_res_in_sunbelt Not in universe
num_persons_worked_for_employer 0
family_members_under_18 Both parents present
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 0
weeks_worked_in_year 0
year 94
income_level 0
4
age 10
class_of_worker Not in universe
detailed_industry_recode 0
detailed_occupation_recode 0
education Children
wage_per_hour 0
enroll_in_edu_inst_last_wk Not in universe
marital_stat Never married
major_industry_code Not in universe or children
major_occupation_code Not in universe
race White
hispanic_origin All other
sex Female
member_of_a_labor_union Not in universe
reason_for_unemployment Not in universe
full_or_part_time_employment_stat Children or Armed Forces
capital_gains 0
capital_losses 0
dividends_from_stocks 0
tax_filer_stat Nonfiler
region_of_previous_residence Not in universe
state_of_previous_residence Not in universe
detailed_household_and_family_stat Child <18 never marr not in subfamily
detailed_household_summary_in_household Child under 18 never married
instance_weight 1069.16
migration_code-change_in_msa Nonmover
migration_code-change_in_reg Nonmover
migration_code-move_within_reg Nonmover
live_in_this_house_1_year_ago Yes
migration_prev_res_in_sunbelt Not in universe
num_persons_worked_for_employer 0
family_members_under_18 Both parents present
country_of_birth_father United-States
country_of_birth_mother United-States
country_of_birth_self United-States
citizenship Native- Born in the United States
own_business_or_self_employed 0
fill_inc_questionnaire_for_veteran's_admin Not in universe
veterans_benefits 0
weeks_worked_in_year 0
year 94
income_level 0
Configure hyperparameters
You can find all the parameters of the Gradient Boosted Trees model in the documentation.
# Maximum number of decision trees. The effective number of trained trees can be smaller if early stopping is enabled.
NUM_TREES = 250
# Minimum number of examples in a node.
MIN_EXAMPLES = 6
# Maximum depth of the tree. max_depth=1 means that all trees will be roots.
MAX_DEPTH = 5
# Ratio of the dataset (sampling without replacement) used to train individual trees for the random sampling method.
SUBSAMPLE = 0.65
# Control the sampling of the datasets used to train individual trees.
SAMPLING_METHOD = "RANDOM"
# Ratio of the training dataset used to monitor the training. Require to be >0 if early stopping is enabled.
VALIDATION_RATIO = 0.1
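As a rough illustration of what the two ratios above mean, the sketch below holds out a validation share and then draws a per-tree subsample without replacement. This is only a stdlib sketch under the assumption that the subsample is taken from the non-validation examples. It is not the library's implementation, and `split_and_subsample` is a hypothetical helper.

```python
import random

SUBSAMPLE = 0.65
VALIDATION_RATIO = 0.1


def split_and_subsample(num_examples, seed=0):
    rng = random.Random(seed)
    indices = list(range(num_examples))
    rng.shuffle(indices)
    # Hold out a share of the examples to monitor training (early stopping).
    num_validation = round(num_examples * VALIDATION_RATIO)
    validation = indices[:num_validation]
    training = indices[num_validation:]
    # Sampling without replacement: each training index appears at most once.
    per_tree = rng.sample(training, round(len(training) * SUBSAMPLE))
    return training, validation, per_tree


training, validation, per_tree = split_and_subsample(1000)
print(len(training), len(validation), len(per_tree))
```

With 1000 toy examples, 100 are held out and each tree would see 585 of the remaining 900.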
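Implement a training and evaluation procedure
The run_experiment() method is responsible for loading the train and test datasets, training a given model, and evaluating the trained model. Note that when training a decision forests model, only one epoch is needed to read the full dataset.
Any extra steps only slow training down unnecessarily. Therefore, run_experiment() uses the default num_epochs=1.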
def run_experiment(model, train_data, test_data, num_epochs=1, batch_size=None):
    train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        train_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )
    test_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(
        test_data, label=TARGET_COLUMN_NAME, weight=WEIGHT_COLUMN_NAME
    )

    model.fit(train_dataset, epochs=num_epochs, batch_size=batch_size)
    _, accuracy = model.evaluate(test_dataset, verbose=0)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
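Experiment 1: Decision Forests with raw features
Specify model input feature usages
You can attach semantics to each feature to control how it is used by the model.
If not specified, the semantics are inferred from the representation type. It is recommended to specify the feature usages explicitly to avoid incorrectly inferred semantics. For example, a categorical value identifier (an integer) will be inferred as numerical, while it is semantically categorical. For numerical features, you can set the discretization parameter to the number of buckets into which the numerical feature should be discretized. This makes training faster but may lead to worse models.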
def specify_feature_usages():
    feature_usages = []

    for feature_name in NUMERIC_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.NUMERICAL
        )
        feature_usages.append(feature_usage)

    for feature_name in CATEGORICAL_FEATURE_NAMES:
        feature_usage = tfdf.keras.FeatureUsage(
            name=feature_name, semantic=tfdf.keras.FeatureSemantic.CATEGORICAL
        )
        feature_usages.append(feature_usage)

    return feature_usages
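The text above mentions discretizing a numerical feature into a fixed number of buckets. The sketch below illustrates the general idea with equal-width buckets in plain Python. The binning strategy and the `discretize` helper are assumptions made for illustration; the library may choose its bucket boundaries differently.

```python
def discretize(values, num_buckets):
    """Map each value to an equal-width bucket index in [0, num_buckets - 1]."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / num_buckets
    buckets = []
    for v in values:
        if width == 0:
            # All values are identical, so everything falls into bucket 0.
            buckets.append(0)
        else:
            # Clamp the maximum value into the last bucket.
            buckets.append(min(int((v - lo) / width), num_buckets - 1))
    return buckets


# Ages of the five instances displayed earlier.
ages = [73, 58, 18, 9, 10]
print(discretize(ages, 4))
```

After discretization a tree only has to consider a handful of split points per feature instead of one per distinct value, which is where the speed-up comes from, at the cost of coarser splits.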
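Create a Gradient Boosted Trees model
When compiling a decision forests model, you can only provide extra evaluation metrics. The loss is specified in the model construction, and the optimizer is irrelevant to decision forests models.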
def create_gbt_model():
    # See all the model parameters in https://www.tensorflow.org/decision_forests/api_docs/python/tfdf/keras/GradientBoostedTreesModel
    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        features=specify_feature_usages(),
        exclude_non_specified_features=True,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )

    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])
    return gbt_model
Train and evaluate the model
gbt_model = create_gbt_model()
run_experiment(gbt_model, train_data, test_data)
Starting reading the dataset
200/200 [==============================] - ETA: 0s
Dataset read in 0:00:08.829036
Training model
Model trained in 0:00:48.639771
Compiling model
200/200 [==============================] - 58s 268ms/step
Test accuracy: 95.79%
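Inspect the model
The model.summary() method displays several types of information about your decision trees model: model type, task, input features, and feature importance.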
print(gbt_model.summary())
Model: "gradient_boosted_trees_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (40):
    age
    capital_gains
    capital_losses
    citizenship
    class_of_worker
    country_of_birth_father
    country_of_birth_mother
    country_of_birth_self
    detailed_household_and_family_stat
    detailed_household_summary_in_household
    detailed_industry_recode
    detailed_occupation_recode
    dividends_from_stocks
    education
    enroll_in_edu_inst_last_wk
    family_members_under_18
    fill_inc_questionnaire_for_veteran's_admin
    full_or_part_time_employment_stat
    hispanic_origin
    live_in_this_house_1_year_ago
    major_industry_code
    major_occupation_code
    marital_stat
    member_of_a_labor_union
    migration_code-change_in_msa
    migration_code-change_in_reg
    migration_code-move_within_reg
    migration_prev_res_in_sunbelt
    num_persons_worked_for_employer
    own_business_or_self_employed
    race
    reason_for_unemployment
    region_of_previous_residence
    sex
    state_of_previous_residence
    tax_filer_stat
    veterans_benefits
    wage_per_hour
    weeks_worked_in_year
    year
Trained with weights
Variable Importance: MEAN_MIN_DEPTH:
    1. "enroll_in_edu_inst_last_wk" 3.942647 ################
    2. "family_members_under_18" 3.942647 ################
    3. "live_in_this_house_1_year_ago" 3.942647 ################
    4. "migration_code-change_in_msa" 3.942647 ################
    5. "migration_code-move_within_reg" 3.942647 ################
    6. "year" 3.942647 ################
    7. "__LABEL" 3.942647 ################
    8. "__WEIGHTS" 3.942647 ################
    9. "citizenship" 3.942137 ###############
    10. "detailed_household_summary_in_household" 3.942137 ###############
    11. "region_of_previous_residence" 3.942137 ###############
    12. "veterans_benefits" 3.942137 ###############
    13. "migration_prev_res_in_sunbelt" 3.940135 ###############
    14. "migration_code-change_in_reg" 3.939926 ###############
    15. "major_occupation_code" 3.937681 ###############
    16. "major_industry_code" 3.933687 ###############
    17. "reason_for_unemployment" 3.926320 ###############
    18. "hispanic_origin" 3.900776 ###############
    19. "member_of_a_labor_union" 3.894843 ###############
    20. "race" 3.878617 ###############
    21. "num_persons_worked_for_employer" 3.818566 ##############
    22. "marital_stat" 3.795667 ##############
    23. "full_or_part_time_employment_stat" 3.795431 ##############
    24. "country_of_birth_mother" 3.787967 ##############
    25. "tax_filer_stat" 3.784505 ##############
    26. "fill_inc_questionnaire_for_veteran's_admin" 3.783607 ##############
    27. "own_business_or_self_employed" 3.776398 ##############
    28. "country_of_birth_father" 3.715252 #############
    29. "sex" 3.708745 #############
    30. "class_of_worker" 3.688424 #############
    31. "weeks_worked_in_year" 3.665290 #############
    32. "state_of_previous_residence" 3.657234 #############
    33. "country_of_birth_self" 3.654377 #############
    34. "age" 3.634295 ############
    35. "wage_per_hour" 3.617817 ############
    36. "detailed_household_and_family_stat" 3.594743 ############
    37. "capital_losses" 3.439298 ##########
    38. "dividends_from_stocks" 3.423652 ##########
    39. "capital_gains" 3.222753 ########
    40. "education" 3.158698 ########
    41. "detailed_industry_recode" 2.981471 ######
    42. "detailed_occupation_recode" 2.364817
Variable Importance: NUM_AS_ROOT:
    1. "education" 33.000000 ################
    2. "capital_gains" 29.000000 ##############
    3. "capital_losses" 24.000000 ###########
    4. "detailed_household_and_family_stat" 14.000000 ######
    5. "dividends_from_stocks" 14.000000 ######
    6. "wage_per_hour" 12.000000 #####
    7. "country_of_birth_self" 11.000000 #####
    8. "detailed_occupation_recode" 11.000000 #####
    9. "weeks_worked_in_year" 11.000000 #####
    10. "age" 10.000000 ####
    11. "state_of_previous_residence" 10.000000 ####
    12. "fill_inc_questionnaire_for_veteran's_admin" 9.000000 ####
    13. "class_of_worker" 8.000000 ###
    14. "full_or_part_time_employment_stat" 8.000000 ###
    15. "marital_stat" 8.000000 ###
    16. "own_business_or_self_employed" 8.000000 ###
    17. "sex" 6.000000 ##
    18. "tax_filer_stat" 5.000000 ##
    19. "country_of_birth_father" 4.000000 #
    20. "race" 3.000000 #
    21. "detailed_industry_recode" 2.000000
    22. "hispanic_origin" 2.000000
    23. "country_of_birth_mother" 1.000000
    24. "num_persons_worked_for_employer" 1.000000
    25. "reason_for_unemployment" 1.000000
Variable Importance: NUM_NODES:
    1. "detailed_occupation_recode" 785.000000 ################
    2. "detailed_industry_recode" 668.000000 #############
    3. "capital_gains" 275.000000 #####
    4. "dividends_from_stocks" 220.000000 ####
    5. "capital_losses" 197.000000 ####
    6. "education" 178.000000 ###
    7. "country_of_birth_mother" 128.000000 ##
    8. "country_of_birth_father" 116.000000 ##
    9. "age" 114.000000 ##
    10. "wage_per_hour" 98.000000 #
    11. "state_of_previous_residence" 95.000000 #
    12. "detailed_household_and_family_stat" 78.000000 #
    13. "class_of_worker" 67.000000 #
    14. "country_of_birth_self" 65.000000 #
    15. "sex" 65.000000 #
    16. "weeks_worked_in_year" 60.000000 #
    17. "tax_filer_stat" 57.000000 #
    18. "num_persons_worked_for_employer" 54.000000 #
    19. "own_business_or_self_employed" 30.000000
    20. "marital_stat" 26.000000
    21. "member_of_a_labor_union" 16.000000
    22. "fill_inc_questionnaire_for_veteran's_admin" 15.000000
    23. "full_or_part_time_employment_stat" 15.000000
    24. "major_industry_code" 15.000000
    25. "hispanic_origin" 9.000000
    26. "major_occupation_code" 7.000000
    27. "race" 7.000000
    28. "citizenship" 1.000000
    29. "detailed_household_summary_in_household" 1.000000
    30. "migration_code-change_in_reg" 1.000000
    31. "migration_prev_res_in_sunbelt" 1.000000
    32. "reason_for_unemployment" 1.000000
    33. "region_of_previous_residence" 1.000000
    34. "veterans_benefits" 1.000000
Variable Importance: SUM_SCORE:
    1. "detailed_occupation_recode" 15392441.075369 ################
    2. "capital_gains" 5277826.822514 #####
    3. "education" 4751749.289550 ####
    4. "dividends_from_stocks" 3792002.951255 ###
    5. "detailed_industry_recode" 2882200.882109 ##
    6. "sex" 2559417.877325 ##
    7. "age" 2042990.944829 ##
    8. "capital_losses" 1735728.772551 #
    9. "weeks_worked_in_year" 1272820.203971 #
    10. "tax_filer_stat" 697890.160846
    11. "num_persons_worked_for_employer" 671351.905595
    12. "detailed_household_and_family_stat" 444620.829557
    13. "class_of_worker" 362250.565331
    14. "country_of_birth_mother" 296311.574426
    15. "country_of_birth_father" 258198.889206
    16. "wage_per_hour" 239764.219048
    17. "state_of_previous_residence" 237687.602572
    18. "country_of_birth_self" 103002.168158
    19. "marital_stat" 102449.735314
    20. "own_business_or_self_employed" 82938.893541
    21. "fill_inc_questionnaire_for_veteran's_admin" 22692.700206
    22. "full_or_part_time_employment_stat" 19078.398837
    23. "major_industry_code" 18450.345505
    24. "member_of_a_labor_union" 14905.360879
    25. "hispanic_origin" 12602.867902
    26. "major_occupation_code" 8709.665989
    27. "race" 6116.282065
    28. "citizenship" 3291.490393
    29. "detailed_household_summary_in_household" 2733.439375
    30. "veterans_benefits" 1230.940488
    31. "region_of_previous_residence" 1139.240981
    32. "reason_for_unemployment" 219.245124
    33. "migration_code-change_in_reg" 55.806436
    34. "migration_prev_res_in_sunbelt" 37.780635
Loss: BINOMIAL_LOG_LIKELIHOOD
Validation loss value: 0.228983
Number of trees per iteration: 1
Node format: NOT_SET
Number of trees: 245
Total number of nodes: 7179
Number of nodes by tree:
Count: 245 Average: 29.302 StdDev: 2.96211
Min: 17 Max: 31 Ignored: 0
----------------------------------------------
[ 17, 18) 2 0.82% 0.82%
[ 18, 19) 0 0.00% 0.82%
[ 19, 20) 3 1.22% 2.04%
[ 20, 21) 0 0.00% 2.04%
[ 21, 22) 4 1.63% 3.67%
[ 22, 23) 0 0.00% 3.67%
[ 23, 24) 15 6.12% 9.80% #
[ 24, 25) 0 0.00% 9.80%
[ 25, 26) 5 2.04% 11.84%
[ 26, 27) 0 0.00% 11.84%
[ 27, 28) 21 8.57% 20.41% #
[ 28, 29) 0 0.00% 20.41%
[ 29, 30) 39 15.92% 36.33% ###
[ 30, 31) 0 0.00% 36.33%
[ 31, 31] 156 63.67% 100.00% ##########
Depth by leafs:
Count: 3712 Average: 3.95259 StdDev: 0.249814
Min: 2 Max: 4 Ignored: 0
----------------------------------------------
[ 2, 3) 32 0.86% 0.86%
[ 3, 4) 112 3.02% 3.88%
[ 4, 4] 3568 96.12% 100.00% ##########
Number of training obs by leaf:
Count: 3712 Average: 11849.3 StdDev: 33719.3
Min: 6 Max: 179360 Ignored: 0
----------------------------------------------
[ 6, 8973) 3100 83.51% 83.51% ##########
[ 8973, 17941) 148 3.99% 87.50%
[ 17941, 26909) 79 2.13% 89.63%
[ 26909, 35877) 36 0.97% 90.60%
[ 35877, 44844) 44 1.19% 91.78%
[ 44844, 53812) 17 0.46% 92.24%
[ 53812, 62780) 20 0.54% 92.78%
[ 62780, 71748) 39 1.05% 93.83%
[ 71748, 80715) 24 0.65% 94.48%
[ 80715, 89683) 12 0.32% 94.80%
[ 89683, 98651) 22 0.59% 95.39%
[ 98651, 107619) 21 0.57% 95.96%
[ 107619, 116586) 17 0.46% 96.42%
[ 116586, 125554) 17 0.46% 96.88%
[ 125554, 134522) 13 0.35% 97.23%
[ 134522, 143490) 8 0.22% 97.44%
[ 143490, 152457) 5 0.13% 97.58%
[ 152457, 161425) 6 0.16% 97.74%
[ 161425, 170393) 15 0.40% 98.14%
[ 170393, 179360] 69 1.86% 100.00%
Attribute in nodes:
    785 : detailed_occupation_recode [CATEGORICAL]
    668 : detailed_industry_recode [CATEGORICAL]
    275 : capital_gains [NUMERICAL]
    220 : dividends_from_stocks [NUMERICAL]
    197 : capital_losses [NUMERICAL]
    178 : education [CATEGORICAL]
    128 : country_of_birth_mother [CATEGORICAL]
    116 : country_of_birth_father [CATEGORICAL]
    114 : age [NUMERICAL]
    98 : wage_per_hour [NUMERICAL]
    95 : state_of_previous_residence [CATEGORICAL]
    78 : detailed_household_and_family_stat [CATEGORICAL]
    67 : class_of_worker [CATEGORICAL]
    65 : sex [CATEGORICAL]
    65 : country_of_birth_self [CATEGORICAL]
    60 : weeks_worked_in_year [NUMERICAL]
    57 : tax_filer_stat [CATEGORICAL]
    54 : num_persons_worked_for_employer [NUMERICAL]
    30 : own_business_or_self_employed [CATEGORICAL]
    26 : marital_stat [CATEGORICAL]
    16 : member_of_a_labor_union [CATEGORICAL]
    15 : major_industry_code [CATEGORICAL]
    15 : full_or_part_time_employment_stat [CATEGORICAL]
    15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    9 : hispanic_origin [CATEGORICAL]
    7 : race [CATEGORICAL]
    7 : major_occupation_code [CATEGORICAL]
    1 : veterans_benefits [CATEGORICAL]
    1 : region_of_previous_residence [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
    1 : detailed_household_summary_in_household [CATEGORICAL]
    1 : citizenship [CATEGORICAL]
Attribute in nodes with depth <= 0:
    33 : education [CATEGORICAL]
    29 : capital_gains [NUMERICAL]
    24 : capital_losses [NUMERICAL]
    14 : dividends_from_stocks [NUMERICAL]
    14 : detailed_household_and_family_stat [CATEGORICAL]
    12 : wage_per_hour [NUMERICAL]
    11 : weeks_worked_in_year [NUMERICAL]
    11 : detailed_occupation_recode [CATEGORICAL]
    11 : country_of_birth_self [CATEGORICAL]
    10 : state_of_previous_residence [CATEGORICAL]
    10 : age [NUMERICAL]
    9 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    8 : own_business_or_self_employed [CATEGORICAL]
    8 : marital_stat [CATEGORICAL]
    8 : full_or_part_time_employment_stat [CATEGORICAL]
    8 : class_of_worker [CATEGORICAL]
    6 : sex [CATEGORICAL]
    5 : tax_filer_stat [CATEGORICAL]
    4 : country_of_birth_father [CATEGORICAL]
    3 : race [CATEGORICAL]
    2 : hispanic_origin [CATEGORICAL]
    2 : detailed_industry_recode [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : num_persons_worked_for_employer [NUMERICAL]
    1 : country_of_birth_mother [CATEGORICAL]
Attribute in nodes with depth <= 1:
    140 : detailed_occupation_recode [CATEGORICAL]
    82 : capital_gains [NUMERICAL]
    65 : capital_losses [NUMERICAL]
    62 : education [CATEGORICAL]
    59 : detailed_industry_recode [CATEGORICAL]
    47 : dividends_from_stocks [NUMERICAL]
    31 : wage_per_hour [NUMERICAL]
    26 : detailed_household_and_family_stat [CATEGORICAL]
    23 : age [NUMERICAL]
    22 : state_of_previous_residence [CATEGORICAL]
    21 : country_of_birth_self [CATEGORICAL]
    21 : class_of_worker [CATEGORICAL]
    20 : weeks_worked_in_year [NUMERICAL]
    20 : sex [CATEGORICAL]
    15 : country_of_birth_father [CATEGORICAL]
    12 : own_business_or_self_employed [CATEGORICAL]
    11 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    10 : num_persons_worked_for_employer [NUMERICAL]
    9 : tax_filer_stat [CATEGORICAL]
    9 : full_or_part_time_employment_stat [CATEGORICAL]
    8 : marital_stat [CATEGORICAL]
    8 : country_of_birth_mother [CATEGORICAL]
    6 : member_of_a_labor_union [CATEGORICAL]
    5 : race [CATEGORICAL]
    2 : hispanic_origin [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
Attribute in nodes with depth <= 2:
    399 : detailed_occupation_recode [CATEGORICAL]
    249 : detailed_industry_recode [CATEGORICAL]
    170 : capital_gains [NUMERICAL]
    117 : dividends_from_stocks [NUMERICAL]
    116 : capital_losses [NUMERICAL]
    87 : education [CATEGORICAL]
    59 : wage_per_hour [NUMERICAL]
    45 : detailed_household_and_family_stat [CATEGORICAL]
    43 : country_of_birth_father [CATEGORICAL]
    43 : age [NUMERICAL]
    40 : country_of_birth_self [CATEGORICAL]
    38 : state_of_previous_residence [CATEGORICAL]
    38 : class_of_worker [CATEGORICAL]
    37 : sex [CATEGORICAL]
    36 : weeks_worked_in_year [NUMERICAL]
    33 : country_of_birth_mother [CATEGORICAL]
    28 : num_persons_worked_for_employer [NUMERICAL]
    26 : tax_filer_stat [CATEGORICAL]
    14 : own_business_or_self_employed [CATEGORICAL]
    14 : marital_stat [CATEGORICAL]
    12 : full_or_part_time_employment_stat [CATEGORICAL]
    12 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    8 : member_of_a_labor_union [CATEGORICAL]
    6 : race [CATEGORICAL]
    6 : hispanic_origin [CATEGORICAL]
    2 : major_occupation_code [CATEGORICAL]
    2 : major_industry_code [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
Attribute in nodes with depth <= 3:
    785 : detailed_occupation_recode [CATEGORICAL]
    668 : detailed_industry_recode [CATEGORICAL]
    275 : capital_gains [NUMERICAL]
    220 : dividends_from_stocks [NUMERICAL]
    197 : capital_losses [NUMERICAL]
    178 : education [CATEGORICAL]
    128 : country_of_birth_mother [CATEGORICAL]
    116 : country_of_birth_father [CATEGORICAL]
    114 : age [NUMERICAL]
    98 : wage_per_hour [NUMERICAL]
    95 : state_of_previous_residence [CATEGORICAL]
    78 : detailed_household_and_family_stat [CATEGORICAL]
    67 : class_of_worker [CATEGORICAL]
    65 : sex [CATEGORICAL]
    65 : country_of_birth_self [CATEGORICAL]
    60 : weeks_worked_in_year [NUMERICAL]
    57 : tax_filer_stat [CATEGORICAL]
    54 : num_persons_worked_for_employer [NUMERICAL]
    30 : own_business_or_self_employed [CATEGORICAL]
    26 : marital_stat [CATEGORICAL]
    16 : member_of_a_labor_union [CATEGORICAL]
    15 : major_industry_code [CATEGORICAL]
    15 : full_or_part_time_employment_stat [CATEGORICAL]
    15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    9 : hispanic_origin [CATEGORICAL]
    7 : race [CATEGORICAL]
    7 : major_occupation_code [CATEGORICAL]
    1 : veterans_benefits [CATEGORICAL]
    1 : region_of_previous_residence [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
    1 : detailed_household_summary_in_household [CATEGORICAL]
    1 : citizenship [CATEGORICAL]
Attribute in nodes with depth <= 5:
    785 : detailed_occupation_recode [CATEGORICAL]
    668 : detailed_industry_recode [CATEGORICAL]
    275 : capital_gains [NUMERICAL]
    220 : dividends_from_stocks [NUMERICAL]
    197 : capital_losses [NUMERICAL]
    178 : education [CATEGORICAL]
    128 : country_of_birth_mother [CATEGORICAL]
    116 : country_of_birth_father [CATEGORICAL]
    114 : age [NUMERICAL]
    98 : wage_per_hour [NUMERICAL]
    95 : state_of_previous_residence [CATEGORICAL]
    78 : detailed_household_and_family_stat [CATEGORICAL]
    67 : class_of_worker [CATEGORICAL]
    65 : sex [CATEGORICAL]
    65 : country_of_birth_self [CATEGORICAL]
    60 : weeks_worked_in_year [NUMERICAL]
    57 : tax_filer_stat [CATEGORICAL]
    54 : num_persons_worked_for_employer [NUMERICAL]
    30 : own_business_or_self_employed [CATEGORICAL]
    26 : marital_stat [CATEGORICAL]
    16 : member_of_a_labor_union [CATEGORICAL]
    15 : major_industry_code [CATEGORICAL]
    15 : full_or_part_time_employment_stat [CATEGORICAL]
    15 : fill_inc_questionnaire_for_veteran's_admin [CATEGORICAL]
    9 : hispanic_origin [CATEGORICAL]
    7 : race [CATEGORICAL]
    7 : major_occupation_code [CATEGORICAL]
    1 : veterans_benefits [CATEGORICAL]
    1 : region_of_previous_residence [CATEGORICAL]
    1 : reason_for_unemployment [CATEGORICAL]
    1 : migration_prev_res_in_sunbelt [CATEGORICAL]
    1 : migration_code-change_in_reg [CATEGORICAL]
    1 : detailed_household_summary_in_household [CATEGORICAL]
    1 : citizenship [CATEGORICAL]
Condition type in nodes:
    2418 : ContainsBitmapCondition
    1018 : HigherCondition
    31 : ContainsCondition
Condition type in nodes with depth <= 0:
    137 : ContainsBitmapCondition
    101 : HigherCondition
    7 : ContainsCondition
Condition type in nodes with depth <= 1:
    448 : ContainsBitmapCondition
    278 : HigherCondition
    9 : ContainsCondition
Condition type in nodes with depth <= 2:
    1097 : ContainsBitmapCondition
    569 : HigherCondition
    17 : ContainsCondition
Condition type in nodes with depth <= 3:
    2418 : ContainsBitmapCondition
    1018 : HigherCondition
    31 : ContainsCondition
Condition type in nodes with depth <= 5:
    2418 : ContainsBitmapCondition
    1018 : HigherCondition
    31 : ContainsCondition
None
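Several figures in the summary above can be cross-checked with a little arithmetic. The sketch below assumes that every internal node has exactly two children and that the depth limit counts the root as level 1. That is an interpretation consistent with the printed maxima, not something the summary states explicitly.

```python
MAX_DEPTH = 5  # as configured in the hyperparameters

# If the root counts as level 1, leaves sit at depth MAX_DEPTH - 1 at most.
max_leaf_depth = MAX_DEPTH - 1
# A full binary tree with leaves at depth d has 2^(d+1) - 1 nodes.
max_nodes_per_tree = 2 ** (max_leaf_depth + 1) - 1

# Figures copied from the summary.
num_trees = 245
total_nodes = 7179
condition_nodes = 2418 + 1018 + 31  # one condition per internal node

# Per tree: leaves = (nodes + 1) / 2, so summed over all trees:
total_leaves = (total_nodes + num_trees) // 2
internal_nodes = total_nodes - total_leaves
average_nodes = total_nodes / num_trees

print(max_leaf_depth, max_nodes_per_tree)
print(total_leaves, internal_nodes, round(average_nodes, 3))
```

The results line up with the summary: at most 31 nodes per tree, a maximum leaf depth of 4, 3712 leaves in total, an average of 29.302 nodes per tree, and 3467 internal nodes, which equals the sum of the three condition counts.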
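Experiment 2: Decision Forests with target encoding
Target encoding is a common preprocessing technique for converting categorical features into numerical features. Using categorical features with high cardinality as-is may lead to overfitting. Target encoding aims to replace each categorical feature value with one or more numerical values that represent its co-occurrence with the target labels.
More precisely, given a categorical feature, the binary target encoder in this example produces three new numerical features:
1. positive_frequency: how many times each feature value occurred with a positive target label.
2. negative_frequency: how many times each feature value occurred with a negative target label.
3. positive_probability: the probability that the target label is positive, given the feature value, computed as positive_frequency / (positive_frequency + negative_frequency + correction).
The correction term is added to make the division more stable for rare categorical values. The default value for correction is 1.0. Note that target encoding is effective with models that cannot automatically learn dense representations of categorical features, such as decision forests or kernel methods. If you use neural network models, it is recommended to encode categorical features as embeddings.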
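The three statistics described above can be worked through by hand on a tiny example. The sketch below uses only the Python standard library. The toy feature values and targets are invented, and `binary_target_statistics` is a hypothetical helper rather than part of any library.

```python
from collections import Counter


def binary_target_statistics(feature_values, targets, correction=1.0):
    """Return {value: (positive_frequency, negative_frequency, positive_probability)}."""
    positive = Counter(v for v, t in zip(feature_values, targets) if t == 1)
    negative = Counter(v for v, t in zip(feature_values, targets) if t == 0)
    stats = {}
    for value in sorted(set(feature_values)):
        pos, neg = positive[value], negative[value]
        stats[value] = (pos, neg, pos / (pos + neg + correction))
    return stats


# Invented toy data: one categorical feature and a binary target.
features = ["a", "b", "a", "b", "c", "a", "d"]
targets = [1, 0, 1, 1, 0, 1, 1]

stats = binary_target_statistics(features, targets)
for value, row in stats.items():
    print(value, row)
```

The value "d" occurs only once, with a positive label. Without the correction its positive probability would be 1.0. With correction=1.0 it becomes 1 / (1 + 0 + 1) = 0.5, which shows how the term tempers estimates for rare values.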
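Implement Binary Target Encoder
For simplicity, we assume that the inputs of the adapt and call methods have the expected data types and shapes, so no validation logic is added. It is recommended to pass the vocabulary_size of the categorical feature to the BinaryTargetEncoding constructor. If it is not specified, it is computed during the execution of the adapt() method.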
class BinaryTargetEncoding(layers.Layer):
    def __init__(self, vocabulary_size=None, correction=1.0, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.correction = correction
        # Set by adapt(); call() checks it before looking anything up.
        self.target_encoding_statistics = None

    def adapt(self, data):
        # data is expected to be an integer numpy array or Tensor of shape [num_examples, 2].
        # It contains the feature values for a given feature in the dataset, and the target values.

        # Convert the data to a tensor.
        data = tf.convert_to_tensor(data)
        # Separate the feature values and target values.
        feature_values = tf.cast(data[:, 0], tf.dtypes.int32)
        target_values = tf.cast(data[:, 1], tf.dtypes.bool)

        # Compute the vocabulary_size if not specified.
        if self.vocabulary_size is None:
            self.vocabulary_size = tf.unique(feature_values).y.shape[0]

        # Filter the data where the target label is positive.
        positive_indices = tf.where(condition=target_values)
        positive_feature_values = tf.gather_nd(
            params=feature_values, indices=positive_indices
        )
        # Compute how many times each feature value occurred with a positive target label.
        positive_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(positive_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=positive_feature_values,
            num_segments=self.vocabulary_size,
        )

        # Filter the data where the target label is negative.
        negative_indices = tf.where(condition=tf.math.logical_not(target_values))
        negative_feature_values = tf.gather_nd(
            params=feature_values, indices=negative_indices
        )
        # Compute how many times each feature value occurred with a negative target label.
        negative_frequency = tf.math.unsorted_segment_sum(
            data=tf.ones(
                shape=(negative_feature_values.shape[0], 1), dtype=tf.dtypes.float64
            ),
            segment_ids=negative_feature_values,
            num_segments=self.vocabulary_size,
        )
        # Compute positive probability for the input feature values.
        positive_probability = positive_frequency / (
            positive_frequency + negative_frequency + self.correction
        )
        # Concatenate the computed statistics for target encoding.
        target_encoding_statistics = tf.cast(
            tf.concat(
                [positive_frequency, negative_frequency, positive_probability], axis=1
            ),
            dtype=tf.dtypes.float32,
        )
        self.target_encoding_statistics = tf.constant(target_encoding_statistics)

    def call(self, inputs):
        # inputs is expected to be an integer numpy array or Tensor of shape [num_examples, 1].
        # It contains the feature values for a given feature in the dataset.

        # Raise an error if the target encoding statistics are not computed.
        if self.target_encoding_statistics is None:
            raise ValueError(
                "You need to call the adapt method to compute target encoding statistics."
            )

        # Convert the inputs to a tensor.
        inputs = tf.convert_to_tensor(inputs)
        # Cast the inputs to an int64 tensor.
        inputs = tf.cast(inputs, tf.dtypes.int64)
        # Look up target encoding statistics for the input feature values.
        target_encoding_statistics = tf.cast(
            tf.gather_nd(self.target_encoding_statistics, inputs),
            dtype=tf.dtypes.float32,
        )
        return target_encoding_statistics
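Let's test the binary target encoder on a small dataset. Each row holds a feature value index and a binary target.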
data = tf.constant(
    [
        [0, 1],
        [2, 0],
        [0, 1],
        [1, 1],
        [1, 1],
        [2, 0],
        [1, 0],
        [0, 1],
        [2, 1],
        [1, 0],
        [0, 1],
        [2, 0],
        [0, 1],
        [1, 1],
        [1, 1],
        [2, 0],
        [1, 0],
        [0, 1],
        [2, 0],
    ]
)

binary_target_encoder = BinaryTargetEncoding()
binary_target_encoder.adapt(data)
print(binary_target_encoder([[0], [1], [2]]))
tf.Tensor(
[[6.         0.         0.85714287]
 [4.         3.         0.5       ]
 [1.         5.         0.14285715]], shape=(3, 3), dtype=float32)
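The three columns printed by the encoder are the positive count, the negative count, and the smoothed positive probability `positive / (positive + negative + correction)`. As a sanity check, the same statistics can be recomputed without TensorFlow. The sketch below is a plain-Python stand-in, not part of the layer; the helper name `target_encoding_stats` is an assumption introduced for illustration.

```python
from collections import Counter


def target_encoding_stats(pairs, correction=1.0):
    """Return {feature_value: (positive_count, negative_count, positive_probability)}."""
    positives = Counter(value for value, target in pairs if target == 1)
    negatives = Counter(value for value, target in pairs if target == 0)
    stats = {}
    for value in sorted(set(positives) | set(negatives)):
        pos, neg = positives[value], negatives[value]
        # Same smoothing as the layer: the correction term keeps the
        # denominator non-zero and lowers the estimate for rare values.
        stats[value] = (pos, neg, pos / (pos + neg + correction))
    return stats


# The same (feature value, target) pairs as in the test above.
pairs = [
    (0, 1), (2, 0), (0, 1), (1, 1), (1, 1), (2, 0), (1, 0), (0, 1), (2, 1), (1, 0),
    (0, 1), (2, 0), (0, 1), (1, 1), (1, 1), (2, 0), (1, 0), (0, 1), (2, 0),
]
stats = target_encoding_stats(pairs)
for value, (pos, neg, prob) in stats.items():
    print(value, pos, neg, round(prob, 4))
```

For feature value 0 this gives 6 positives, 0 negatives, and 6 / (6 + 0 + 1) ≈ 0.8571, which agrees with the first row of the tensor above.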
Create model inputs
def create_model_inputs():
    inputs = {}

    # Numeric features are fed as float scalars.
    for feature_name in NUMERIC_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.float32
        )

    # Categorical features are fed as string scalars.
    for feature_name in CATEGORICAL_FEATURE_NAMES:
        inputs[feature_name] = layers.Input(
            name=feature_name, shape=(), dtype=tf.string
        )

    return inputs
Implement feature encoding with target encoding
def create_target_encoder():
    inputs = create_model_inputs()
    target_values = train_data[[TARGET_COLUMN_NAME]].to_numpy()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # Get the vocabulary of the categorical feature.
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and num_oov_indices to 0.
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # Convert the string input values into integer indices.
            value_indices = lookup(inputs[feature_name])
            # Prepare the data to adapt the target encoding.
            print("### Adapting target encoding for:", feature_name)
            feature_values = train_data[[feature_name]].to_numpy().astype(str)
            feature_value_indices = lookup(feature_values)
            data = tf.concat([feature_value_indices, target_values], axis=1)
            feature_encoder = BinaryTargetEncoding()
            feature_encoder.adapt(data)
            # Convert the feature value indices to target encoding representations.
            encoded_feature = feature_encoder(tf.expand_dims(value_indices, -1))
        else:
            # Expand the dimensions of the numerical input feature and use it as-is.
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # Add the encoded feature to the list.
        encoded_features.append(encoded_feature)
    # Concatenate all the encoded features.
    encoded_features = tf.concat(encoded_features, axis=1)
    # Create and return a Keras model with encoded features as outputs.
    return keras.Model(inputs=inputs, outputs=encoded_features)
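With no mask token and no out-of-vocabulary slot, the lookup step simply maps each string to its position in the sorted vocabulary, and those positions are what the target encoder receives as feature values. The following is a rough plain-Python analogue of that mapping, not the Keras layer itself; `build_lookup` and the sample column are assumptions made up for illustration.

```python
def build_lookup(values):
    """Map each distinct string to its index in the sorted vocabulary."""
    vocabulary = sorted({str(value) for value in values})
    return {token: index for index, token in enumerate(vocabulary)}


# Hypothetical sample column; not taken from the census dataset.
column = ["Private", "State government", "Private", "Self-employed", "Private"]
lookup = build_lookup(column)
indices = [lookup[value] for value in column]

print(lookup)   # {'Private': 0, 'Self-employed': 1, 'State government': 2}
print(indices)  # [0, 2, 0, 1, 0]
```

In this sketch a value that never appeared in the column raises a `KeyError`, which is why the vocabulary is built from every distinct value of the training column.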
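Create a Gradient Boosted Trees model with a preprocessor
In this scenario, we use the target encoding as a preprocessor for the Gradient Boosted Trees model, and let the model infer the semantics of the input features.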
def create_gbt_with_preprocessor(preprocessor):
    gbt_model = tfdf.keras.GradientBoostedTreesModel(
        preprocessing=preprocessor,
        num_trees=NUM_TREES,
        max_depth=MAX_DEPTH,
        min_examples=MIN_EXAMPLES,
        subsample=SUBSAMPLE,
        validation_ratio=VALIDATION_RATIO,
        task=tfdf.keras.Task.CLASSIFICATION,
    )

    gbt_model.compile(metrics=[keras.metrics.BinaryAccuracy(name="accuracy")])

    return gbt_model
Train and evaluate the model
gbt_model = create_gbt_with_preprocessor(create_target_encoder())
run_experiment(gbt_model, train_data, test_data)
### Adapting target encoding for: class_of_worker
### Adapting target encoding for: detailed_industry_recode
### Adapting target encoding for: detailed_occupation_recode
### Adapting target encoding for: education
### Adapting target encoding for: enroll_in_edu_inst_last_wk
### Adapting target encoding for: marital_stat
### Adapting target encoding for: major_industry_code
### Adapting target encoding for: major_occupation_code
### Adapting target encoding for: race
### Adapting target encoding for: hispanic_origin
### Adapting target encoding for: sex
### Adapting target encoding for: member_of_a_labor_union
### Adapting target encoding for: reason_for_unemployment
### Adapting target encoding for: full_or_part_time_employment_stat
### Adapting target encoding for: tax_filer_stat
### Adapting target encoding for: region_of_previous_residence
### Adapting target encoding for: state_of_previous_residence
### Adapting target encoding for: detailed_household_and_family_stat
### Adapting target encoding for: detailed_household_summary_in_household
### Adapting target encoding for: migration_code-change_in_msa
### Adapting target encoding for: migration_code-change_in_reg
### Adapting target encoding for: migration_code-move_within_reg
### Adapting target encoding for: live_in_this_house_1_year_ago
### Adapting target encoding for: migration_prev_res_in_sunbelt
### Adapting target encoding for: family_members_under_18
### Adapting target encoding for: country_of_birth_father
### Adapting target encoding for: country_of_birth_mother
### Adapting target encoding for: country_of_birth_self
### Adapting target encoding for: citizenship
### Adapting target encoding for: own_business_or_self_employed
### Adapting target encoding for: fill_inc_questionnaire_for_veteran's_admin
### Adapting target encoding for: veterans_benefits
### Adapting target encoding for: year
Use /tmp/tmpj_0h78ld as temporary training directory
Starting reading the dataset
198/200 [============================>.] - ETA: 0s
Dataset read in 0:00:06.793717
Training model
Model trained in 0:04:32.752691
Compiling model
200/200 [==============================] - 280s 1s/step
Test accuracy: 95.81%
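Experiment 3: Decision Forests with trained embeddings
In this scenario, we build an encoder model that encodes the categorical features as embeddings, where the embedding size for a given categorical feature is the square root of its vocabulary size, rounded down. We train these embeddings in a simple NN model through backpropagation.
Once the embedding encoder is trained, we use it as a preprocessor for the input features of a Gradient Boosted Trees model.
Note that the embeddings and the decision forest model cannot be trained jointly in a single phase, because decision forest models are not trained with backpropagation. Instead, the embeddings have to be trained in an initial phase and then used as static inputs to the decision forest model.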
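To make the square-root sizing rule concrete, here is a small sketch of how the embedding width grows with the vocabulary. The helper name `embedding_size` and the vocabulary sizes are assumptions chosen only for illustration; they are not the actual vocabulary sizes of the census features.

```python
import math


def embedding_size(vocabulary_size):
    """Embedding width under the square-root rule, rounded down."""
    return int(math.sqrt(vocabulary_size))


# Hypothetical vocabulary sizes, chosen only for illustration.
sizes = {n: embedding_size(n) for n in (2, 9, 52, 100)}
for vocabulary_size, width in sizes.items():
    print(vocabulary_size, "->", width)
```

A binary feature therefore gets a single embedding dimension, while a feature with 52 distinct values gets 7, so high-cardinality features are compressed far more aggressively than one-hot encoding would allow.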
Implement feature encoding with embeddings
def create_embedding_encoder(size=None):
    inputs = create_model_inputs()
    encoded_features = []
    for feature_name in inputs:
        if feature_name in CATEGORICAL_FEATURE_NAMES:
            # Get the vocabulary of the categorical feature.
            vocabulary = sorted(
                [str(value) for value in list(train_data[feature_name].unique())]
            )
            # Create a lookup to convert string values to integer indices.
            # Since we are not using a mask token nor expecting any out of vocabulary
            # (oov) token, we set mask_token to None and num_oov_indices to 0.
            lookup = layers.StringLookup(
                vocabulary=vocabulary, mask_token=None, num_oov_indices=0
            )
            # Convert the string input values into integer indices.
            value_index = lookup(inputs[feature_name])
            # Create an embedding layer with the specified dimensions.
            vocabulary_size = len(vocabulary)
            embedding_size = int(math.sqrt(vocabulary_size))
            feature_encoder = layers.Embedding(
                input_dim=vocabulary_size, output_dim=embedding_size
            )
            # Convert the index values to embedding representations.
            encoded_feature = feature_encoder(value_index)
        else:
            # Expand the dimensions of the numerical input feature and use it as-is.
            encoded_feature = tf.expand_dims(inputs[feature_name], -1)
        # Add the encoded feature to the list.
        encoded_features.append(encoded_feature)
    # Concatenate all the encoded features.
    encoded_features = layers.concatenate(encoded_features, axis=1)
    # Apply dropout.
    encoded_features = layers.Dropout(rate=0.25)(encoded_features)
    # Apply a non-linear projection.
    encoded_features = layers.Dense(
        units=size if size else encoded_features.shape[-1], activation="gelu"
    )(encoded_features)
    # Create and return a Keras model with encoded features as outputs.
    return keras.Model(inputs=inputs, outputs=encoded_features)
Build an NN model to train the embeddings
def create_nn_model(encoder):
    inputs = create_model_inputs()
    embeddings = encoder(inputs)
    output = layers.Dense(units=1, activation="sigmoid")(embeddings)

    nn_model = keras.Model(inputs=inputs, outputs=output)
    nn_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[keras.metrics.BinaryAccuracy("accuracy")],
    )
    return nn_model


embedding_encoder = create_embedding_encoder(size=64)
run_experiment(
    create_nn_model(embedding_encoder),
    train_data,
    test_data,
    num_epochs=5,
    batch_size=256,
)
Epoch 1/5
200/200 [==============================] - 10s 27ms/step - loss: 8303.1455 - accuracy: 0.9193
Epoch 2/5
200/200 [==============================] - 5s 27ms/step - loss: 1019.4900 - accuracy: 0.9371
Epoch 3/5
200/200 [==============================] - 5s 27ms/step - loss: 612.2844 - accuracy: 0.9416
Epoch 4/5
200/200 [==============================] - 5s 27ms/step - loss: 858.9774 - accuracy: 0.9397
Epoch 5/5
200/200 [==============================] - 5s 26ms/step - loss: 842.3922 - accuracy: 0.9421
Test accuracy: 95.0%
Train and evaluate a Gradient Boosted Trees model with embeddings
gbt_model = create_gbt_with_preprocessor(embedding_encoder)
run_experiment(gbt_model, train_data, test_data)
Use /tmp/tmpao5o88p6 as temporary training directory
Starting reading the dataset
199/200 [============================>.] - ETA: 0s
Dataset read in 0:00:06.722677
Training model
Model trained in 0:05:18.350298
Compiling model
200/200 [==============================] - 325s 2s/step
Test accuracy: 95.82%
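Concluding remarks
TensorFlow Decision Forests provides powerful models, especially for structured data. In our experiments, the Gradient Boosted Trees model achieved 95.79% test accuracy. When target encoding was applied to the categorical features, the same model achieved 95.81% test accuracy. When pretrained embeddings were used as inputs to the Gradient Boosted Trees model, we achieved 95.82% test accuracy.
Decision forests can be used together with neural networks, either by:
1) using a neural network to learn a useful representation of the input data, and then using a decision forest for the supervised learning task; or
2) creating an ensemble of decision forest and neural network models.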
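The second option, ensembling, is not implemented in this article. One simple way to do it, sketched here under the assumption that both models output a positive-class probability for the same examples, is a weighted average of those probabilities. The function name `average_ensemble`, the `forest_weight` parameter, and the sample probabilities are all hypothetical.

```python
def average_ensemble(forest_probabilities, network_probabilities, forest_weight=0.5):
    """Blend two lists of positive-class probabilities with a weighted average."""
    if len(forest_probabilities) != len(network_probabilities):
        raise ValueError("Both models must score the same examples.")
    return [
        forest_weight * p_forest + (1.0 - forest_weight) * p_network
        for p_forest, p_network in zip(forest_probabilities, network_probabilities)
    ]


# Hypothetical predictions for four examples; not results from this article.
forest_probabilities = [0.90, 0.20, 0.60, 0.40]
network_probabilities = [0.70, 0.10, 0.30, 0.80]

blended = average_ensemble(forest_probabilities, network_probabilities)
labels = [int(p >= 0.5) for p in blended]
print(labels)  # [1, 0, 0, 1]
```

In practice the weight would be tuned on a validation set, and other blending schemes (for example stacking) are equally plausible; this is only the simplest instance of the idea.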
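Note that TensorFlow Decision Forests does not (yet) support hardware accelerators. All training and inference is done on the CPU. In addition, the training procedures of decision forests require a finite dataset that fits in memory. However, increasing the size of the dataset yields diminishing returns, and decision forest algorithms arguably need fewer examples to converge than large neural network models.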