当前位置:网站首页>fastText学习——文本分类
fastText学习——文本分类
2022-07-29 05:21:00 【Quinn-ntmy】
之前主要有One-hot、Bag of Words、N-gram、TF-IDF词向量表示方法,但它们存在不足:
- 转换得到的向量维度很高,需要较长训练时间;
- 没有考虑单词与单词之间的关系,只是进行了统计。
且优于TF-IDF具体表现在:
1、FastText用单词的Embedding叠加获得的文档向量,将相似的句子分为一类;
2、FastText学习到的Embedding空间维度比较低,可以快速进行训练。
后将深度学习应用于文本表示,典型例子:fastText、Word2Vec、Bert。
接下来本文主要介绍fastText:
- 通过Embedding层将单词映射到稠密空间,然后将将整篇文档的词及n-gram向量叠加平均得到文档向量,进而做softmax多分类操作——可以大大减少模型训练时间。
- 主要涉及2个trick:字符级n-gram特征的引入以及分层Softmax分类
- 三层神经网络:输入层、隐含层和输出层。
输入是多个向量化的单词,附加了字符级别的n-gram作为特征输入;输出都是一个特定的target(文本对应的类标);隐含层是对多个词向量的叠加平均。
首先看一下fastText的网络结构:
# 使用keras实现FastText网络结构
from __future__ import unicode_literals
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import GlobalAveragePooling1D
from tensorflow.keras.layers import Dense
Vocab_size = 2000
Embedding_dim = 100
Max_words = 500
Class_num = 5
def build_fastText():
model = Sequential() # 看做一个容器
# 通过embedding层,将词汇映射成Embedding_dim维向量
model.add(Embedding(Vocab_size, Embedding_dim, input_length=Max_words))
# 通过GlobalAveragePooling1D,平均文档中所有词的embedding
model.add(GlobalAveragePooling1D())
# 通过输出层Softmax分类(真实的fastText这块是分层Softmax),得到类别概率分布
model.add(Dense(Class_num, activation='softmax'))
# 定义损失函数、优化器、分类度量指标
model.compile(loss='categorical_crossentropy', optimizer='SGD', metrics=['accuracy'])
return model
if __name__ == '__main__':
model = build_fastText()
print(model.summary())
fastText文本分类流程图:
【注】:重要知识点基本在注释上
- 数据读取
import pandas as pd
from sklearn.metrics import f1_score
train_df = pd.read_csv('../data/train_set.csv', sep='\t', nrows=15000)
- 数据处理
将数据转换为fastText需要的格式
train_df['label_ft'] = '__label__' + train_df['label'].astype(str)
# __label__: 类别前缀,__label__后面接类别,比如上面是str
train_df[['text', 'label_ft']].iloc[:-5000].to_csv('train.csv', index=None, header=None, sep='\t')
# iloc方法提供了基于整数的索引方式
- 训练模型
import fasttext
# fasttext.train_unsupervised()无监督用来训练词向量;fasttext.train_supervised()训练一个监督模型,返回一个模型对象
model = fasttext.train_supervised('train.csv', lr=1.0, wordNgrams=2,
verbose=2, minCount=1, epoch=25, loss='hs')
# 'hs'指hierarchical softmax(分层softmax) 类别数较多时,通过构建一个霍夫曼编码树来加速softmax layer的计算,和word2vec中的trick相同;
# minCount——词频阈值, 小于该值在初始化时会过滤掉
# verbose =0时,不输出日志信息,进度条、loss、acc这些都不输出; =1时,输出带进度条的输出日志信息; =2时,为每个epoch输出一行记录(不带进度条)
fasttext.train_supervised()
参数说明:
input 训练文件路径(必须)
lr 学习率 default 0.1
label 类别前缀 default __label__
lrUpdateRate 学习率更新速率 default 100
dim 词向量维度 default 100
ws 上下文窗口大小 default 5, cbow
epoch epochs 数量 default 5
minCount 最低词频 default 5
minCountLabel 类别阈值,类别小于该值初始化时会过滤掉
wordNgrams n-gram设置 default 1
loss 损失函数 {ns,hs,softmax} default softmax
minn 最小字符长度 default 0
maxn 最大字符长度 default 0
thread 线程数量 default 12
t 采样阈值 default 0.0001
silent 禁用 c++ 扩展日志输出 default 1
encoding 指定 input_file 编码 default utf-8
pretrainedVectors 预训练的词向量文件路径, 如果word出现在文件夹中初始化不再随机 default None
- 模型预测及评估
val_pred = [model.predict(x)[0][0].split('__')[-1] for x in train_df.iloc[-5000:]['text']]
print(f1_score(train_df['label'].values[-5000:].astype(str), val_pred, average='macro')) # 得分是0.8214....
边栏推荐
- 30 knowledge points that must be mastered in quantitative development [what is level-2 data]
- Process management of day02 operation
- 【Transformer】SOFT: Softmax-free Transformer with Linear Complexity
- 微信小程序源码获取(附工具的下载)
- 第三周周报 ResNet+ResNext
- Flutter正在被悄悄放弃?浅析Flutter的未来
- 虚假新闻检测论文阅读(二):Semi-Supervised Learning and Graph Neural Networks for Fake News Detection
- Thinkphp6 output QR code image format to solve the conflict with debug
- [overview] image classification network
- How to PR an open source composer project
猜你喜欢
Intelligent security of the fifth space ⼤ real competition problem ----------- PNG diagram ⽚ converter
How to PR an open source composer project
研究生新生培训第一周:深度学习和pytorch基础
【Transformer】ACMix:On the Integration of Self-Attention and Convolution
How to make interesting apps for deep learning with zero code (suitable for novices)
ROS教程(Xavier)
[overview] image classification network
Ribbon learning notes II
Super simple integration of HMS ml kit to realize parent control
【ML】机器学习模型之PMML--概述
随机推荐
[DL] introduction and understanding of tensor
Detailed explanation of tool classes countdownlatch and cyclicbarrier of concurrent programming learning notes
[pycharm] pycharm remote connection server
【Transformer】ACMix:On the Integration of Self-Attention and Convolution
Super simple integration HMS ml kit face detection to achieve cute stickers
Detailed explanation of atomic operation class atomicinteger in learning notes of concurrent programming
Technology that deeply understands the principle of MMAP and makes big manufacturers love it
研究生新生培训第二周:卷积神经网络基础
Huawei 2020 school recruitment written test programming questions read this article is enough (Part 1)
Are you sure you know the interaction problem of activity?
Exploration of flutter drawing skills: draw arrows together (skill development)
【Transformer】AdaViT: Adaptive Vision Transformers for Efficient Image Recognition
Reporting Services- Web Service
第三周周报 ResNet+ResNext
微信小程序源码获取(附工具的下载)
【数据库】数据库课程设计一一疫苗接种数据库
ASM插桩:学完ASM Tree api,再也不用怕hook了
主流实时流处理计算框架Flink初体验。
【Transformer】SOFT: Softmax-free Transformer with Linear Complexity
Most PHP programmers don't understand how to deploy safe code