![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/39983/17/20163/10837/635f9121Ea059ac16/848c1783167b3d90.png">
作者:韩信子@ShowMeAI
深度学习实战系列:https://www.showmeai.tech/tutorials/42
TensorFlow 实战系列:https://www.showmeai.tech/tutorials/43
本文地址:https://www.showmeai.tech/article-detail/315
声明:版权所有,转载请联系平台与作者并注明出处
收藏ShowMeAI查看更多精彩内容
![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/135876/13/30245/112173/635f9124E1cb8d088/ea4df6cd965f3bb3.png">
自 Transformers 出现以来,基于它的结构已经颠覆了自然语言处理和计算机视觉,带来各种非结构化数据业务场景和任务的巨大效果突破,接着大家把目光转向了结构化业务数据,它是否能在结构化表格数据上同样有惊人的效果表现呢?
答案是YES!亚马逊在论文中提出的 TabTransformer,是一种把结构调整后适应于结构化表格数据的网络结构,它更擅长于捕捉传统结构化表格数据中不同类型的数据信息,并将其结合以完成预估任务。下面ShowMeAI给大家讲解构建 TabTransformer 并将其应用于结构化数据上的过程。
![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/182516/5/22588/18046/635f9125E38822dd9/03b4990395908bef.png">
环境设置
本篇使用到的深度学习框架为TensorFlow,大家需要安装2.7或更高版本, 我们还需要安装一下 TensorFlow插件addons,安装的过程大家可以通过下述命令完成:
![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/124212/23/30940/82382/635f9125E70d9c44b/b7f503820bf42152.png">
关于本篇代码实现中使用到的TensorFlow工具库,大家可以查看ShowMeAI制作的TensorFlow速查手册快学快用:
接下来我们导入工具库
数据说明
ShowMeAI在本例中使用到的是 美国人口普查收入数据集,任务是根据人口基本信息预测其年收入是否可能超过 50,000 美元,是一个二分类问题。
![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/190739/12/29611/117225/635f9129Ef4c046dd/752a60d4143332f3.png">
数据集可以在以下地址下载:
https://archive.ics.uci.edu/ml/datasets/Adult
https://archive.ics.uci.edu/ml/machine-learning-databases/adult/
数据从美国1994年人口普查数据库抽取而来,可以用来预测居民收入是否超过50K/year。该数据集类变量为年收入是否超过50k,属性变量包含年龄、工种、学历、职业、人种等重要信息,值得一提的是,14个属性变量中有7个类别型变量。数据集各属性是:其中序号0~13是属性,14是类别。
字段序号 |
字段名 |
含义 |
类型 |
0 |
age |
年龄 |
Double |
1 |
workclass |
工作类型* |
string |
2 |
fnlwgt |
序号 |
string |
3 |
education |
教育程度* |
string |
4 |
education_num |
受教育时间 |
double |
5 |
maritial_status |
婚姻状况* |
string |
6 |
occupation |
职业* |
string |
7 |
relationship |
关系* |
string |
8 |
race |
种族* |
string |
9 |
sex |
性别* |
string |
10 |
capital_gain |
资本收益 |
string |
11 |
capital_loss |
资本损失 |
string |
12 |
hours_per_week |
每周工作小时数 |
double |
13 |
native_country |
原籍* |
string |
14(label) |
income |
收入标签 |
string |
![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/169843/18/31590/341224/635f912aEa1056b90/b57aa7852232ebea.png">
我们先用pandas读取数据到dataframe中:
我们做点数据清洗,把测试集第一条记录剔除(它不是有效的数据示例),把类标签中的尾随的“点”去掉。
再把训练集和测试集存回单独的 CSV 文件中。
模型原理
TabTransformer的模型架构如下所示:
![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/109683/38/22950/55348/635f912bE6605ea18/98b5835edaa200b0.png">
我们可以看到,类别型的特征,很适合在 embedding 后,送入 transformer 模块进行深度交叉组合与信息挖掘,得到的信息与右侧的连续值特征进行拼接,再送入全连接的 MLP 模块进行组合和完成最后的任务(分类或者回归)。
模型实现
定义数据集元数据
要实现模型,我们先对输入数据字段,区分不同的类型(数值型特征与类别型特征)。我们会对不同类型的特征,使用不同的方式进行处理和完成特征工程(例如数值型的特征进行幅度缩放,类别型的特征进行编码处理)。
配置超参数
我们为神经网络的结构和训练过程的超参数进行设置,如下。
实现数据读取管道
下面我们定义一个输入函数,它负责读取和解析文件,并对特征和标签处理,放入 tf.data.Dataset
,以便后续训练和评估。
模型构建与评估
① 创建模型输入
基于 TensorFlow 的输入要求,我们将模型的输入定义为字典,其中『key/键』是特征名称,『value/值』为 keras.layers.Input
具有相应特征形状的张量和数据类型。
② 编码特征
我们定义一个encode_inputs
函数,返回encoded_categorical_feature_list
和 numerical_feature_list
。我们将分类特征编码为嵌入,使用固定的embedding_dims
对于所有功能, 无论他们的词汇量大小。 这是 Transformer 模型所必需的。
③ MLP模块实现
网络中不可或缺的部分是 MLP 全连接板块,下面是它的简单实现:
④ 模型实现1:基线模型
为了对比效果,我们先简单使用MLP(多层前馈网络)进行建模,代码和注释如下。
上述模型构建完成之后,我们通过plot_model操作,绘制出模型结构如下:
![只能用于文本与图像数据?No!看TabTransformer嚆��]() vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/51096/23/22612/328131/635f912eEcaf6b052/37fecc62d473c408.png">
|
接下来我们训练和评估一下基线模型:
输出的训练过程日志如下:
我们可以看到基线模型(全连接MLP网络)实现了约 82% 的验证准确度。
⑤ 模型实现2:TabTransformer
![只能用于文本与图像数据?No!看TabTransformer嚆��]()
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/173399/25/30888/55069/635f912fEb8d0053b/f2b72ffddc77e601.png">
TabTransformer 架构的工作原理如下:
- 所有类别型特征都被编码为嵌入,使用相同的
embedding_dims
。
- 将列嵌入(每个类别型特征的一个嵌入向量)添加类别型特征嵌入中。
- 嵌入的类别型特征被输入到一系列的 Transformer 块中。 每个 Transformer 块由一个多头自注意力层和一个前馈层组成。
- 最终 Transformer 层的输出, 与输入的数值型特征连接,并输入到最终的 MLP 块中。
- 尾部由一个
softmax
结构完成分类。
def create_tabtransformer_classifier(
num_transformer_blocks,
num_heads,
embedding_dims,
mlp_hidden_units_factors,
dropout_rate,
use_column_embedding=False,
):
inputs = create_model_inputs()
encoded_categorical_feature_list, numerical_feature_list = encode_inputs(
inputs, embedding_dims
)
encoded_categorical_features = tf.stack(encoded_categorical_feature_list, axis=1)
numerical_features = layers.concatenate(numerical_feature_list)
if use_column_embedding:
num_columns = encoded_categorical_features.shape[1]
column_embedding = layers.Embedding(
input_dim=num_columns, output_dim=embedding_dims
)
column_indices = tf.range(start=0, limit=num_columns, delta=1)
encoded_categorical_features = encoded_categorical_features + column_embedding(
column_indices
)
for block_idx in range(num_transformer_blocks):
attention_output = layers.MultiHeadAttention(
num_heads=num_heads,
key_dim=embedding_dims,
dropout=dropout_rate,
name=f"multihead_attention_{block_idx}",
)(encoded_categorical_features, encoded_categorical_features)
x = layers.Add(name=f"skip_connection1_{block_idx}")(
[attention_output, encoded_categorical_features]
)
x = layers.LayerNormalization(name=f"layer_norm1_{block_idx}", epsilon=1e-6)(x)
feedforward_output = create_mlp(
hidden_units=[embedding_dims],
dropout_rate=dropout_rate,
activation=keras.activations.gelu,
normalization_layer=layers.LayerNormalization(epsilon=1e-6),
name=f"feedforward_{block_idx}",
)(x)
x = layers.Add(name=f"skip_connection2_{block_idx}")([feedforward_output, x])
encoded_categorical_features = layers.LayerNormalization(
name=f"layer_norm2_{block_idx}", epsilon=1e-6
)(x)
categorical_features = layers.Flatten()(encoded_categorical_features)
numerical_features = layers.LayerNormalization(epsilon=1e-6)(numerical_features)
features = layers.concatenate([categorical_features, numerical_features])
mlp_hidden_units = [
factor * features.shape[-1] for factor in mlp_hidden_units_factors
]
features = create_mlp(
hidden_units=mlp_hidden_units,
dropout_rate=dropout_rate,
activation=keras.activations.selu,
normalization_layer=layers.BatchNormalization(),
name="MLP",
)(features)
outputs = layers.Dense(units=1, activation="sigmoid", name="sigmoid")(features)
model = keras.Model(inputs=inputs, outputs=outputs)
return model
tabtransformer_model = create_tabtransformer_classifier(
num_transformer_blocks=NUM_TRANSFORMER_BLOCKS,
num_heads=NUM_HEADS,
embedding_dims=EMBEDDING_DIMS,
mlp_hidden_units_factors=MLP_HIDDEN_UNITS_FACTORS,
dropout_rate=DROPOUT_RATE,
)
print("Total model weights:", tabtransformer_model.count_params())
keras.utils.plot_model(tabtransformer_model, show_shapes=True, rankdir="LR")
最终输出的模型结构示意图如下(因为模型结构较深,总体很长,点击放大)
![只能用于文本与图像数据?No!看TabTransformer嚆��]() vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/141485/13/30331/407790/635f9133Eb5647fae/fceec18aac5d6e89.png">
|
下面我们训练和评估一下TabTransformer 模型的效果:
TabTransformer 模型实现了约 85% 的验证准确度,相比于直接使用全连接网络效果有一定的提升。
参考资料
vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" title="只能用于文本与图像数据?No!看TabTransformer嚆��"https://www.huyubaike.com/tag/1905.html" target="_blank" class="yzm-keyword-link">vuee7k+aehOWMluS4muWKoeaVsOaNrueyvuWHhuW7uuaooQ==" src="https://m.360buyimg.com/jdcms/jfs/t1/161974/11/31854/110576/635f9134E98c0eab4/9062d84760f69033.png">
标签:
留言评论