News Classification with an LSTM Model
2022-07-01 21:32:00 【拼命_小李】
1. A Brief Overview of the LSTM Model
LSTM is the Long Short-Term Memory neural network. Judging from the published literature, it is mostly applied to classification, machine translation, sentiment recognition, and similar tasks. In this article we use TensorFlow and Keras to build an LSTM model for a news-classification case study. (We only discuss and implement the application; the underlying principles of the model are not covered.)
2. Data Preprocessing
Two inputs are needed up front: the news data and a stop-word file. We use jieba for word segmentation and pandas for preprocessing; the dataset contains 12,000 records in total. The raw dataset is shown in the figure below:

First read the stop-word list; then load the data file with pandas and use jieba to segment each line and remove stop words. The preprocessing code is as follows:
def get_custom_stopwords(stop_words_file):
    # Read the stop-word file and return one stop word per line
    with open(stop_words_file, encoding='utf-8') as f:
        stopwords = f.read()
    return stopwords.split('\n')

cachedStopWords = get_custom_stopwords("stopwords.txt")

import pandas as pd
import jieba

# Two tab-separated columns: column 0 = category label, column 1 = news text
data = pd.read_csv("sohu_test.txt", sep="\t", header=None)

# Map each distinct category string to an integer id
label_dict = {v: k for k, v in enumerate(data[0].unique())}
data[0] = data[0].map(label_dict)

def chinese_word_cut(mytext):
    # Segment with jieba and drop stop words, joining tokens with spaces
    return " ".join(word for word in jieba.cut(mytext) if word not in cachedStopWords)

data[1] = data[1].apply(chinese_word_cut)
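As a quick sanity check on the label mapping above, here is a minimal, self-contained example on a toy two-column frame (the category names are made up for illustration):

```python
import pandas as pd

# Toy frame mimicking the two-column news file: column 0 = category, column 1 = text
data = pd.DataFrame({0: ["sports", "finance", "sports"],
                     1: ["match report", "market news", "score update"]})

# Map each distinct category string to an integer id, as in the preprocessing above
label_dict = {v: k for k, v in enumerate(data[0].unique())}
data[0] = data[0].map(label_dict)

print(label_dict)        # {'sports': 0, 'finance': 1}
print(data[0].tolist())  # [0, 1, 0]
```

`unique()` preserves first-appearance order, so the first category seen always gets id 0.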
data

3. Text Vectorization
Set the initial model parameters: batch_size is the number of samples per training batch; class_size the number of categories; epochs the number of training epochs; num_words the number of most frequent words to keep (the Embedding layer needs vocabulary size + 1 as its input dimension); max_len the length of each padded text sequence. A Tokenizer builds the integer sequences, which are then padded:
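To make the vectorization step concrete before the real code, here is a pure-Python imitation of what texts_to_sequences and pad_sequences produce (illustrative only; this is not the Keras implementation):

```python
def texts_to_sequences(texts, word_index):
    # Replace each known word with its integer index; unknown words are dropped
    return [[word_index[w] for w in t.split() if w in word_index] for t in texts]

def pad_sequences(seqs, maxlen):
    # Left-pad with zeros and truncate from the front (Keras defaults: padding='pre', truncating='pre')
    return [[0] * (maxlen - len(s)) + s[-maxlen:] for s in seqs]

word_index = {"news": 1, "sports": 2, "finance": 3}
seqs = texts_to_sequences(["sports news today", "finance news"], word_index)
print(seqs)                    # [[2, 1], [3, 1]]
print(pad_sequences(seqs, 4))  # [[0, 0, 2, 1], [0, 0, 3, 1]]
```

Index 0 is reserved for padding, which is why the vocabulary indices start at 1.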
batch_size = 32   # samples per training batch
class_size = 12   # number of news categories
epochs = 300      # training epochs
num_words = 5000  # keep only the most frequent words
max_len = 600     # padded sequence length
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing import sequence
tokenizer = Tokenizer(num_words=num_words)
tokenizer.fit_on_texts(data[1])
# print(tokenizer.word_index)
# train = tokenizer.texts_to_matrix(data[1])
train = tokenizer.texts_to_sequences(data[1])
train = sequence.pad_sequences(train, maxlen=max_len)

4. Model Construction
Use train_test_split to split the dataset, then build the model:
from tensorflow.keras.layers import *
from tensorflow.keras import Sequential
from tensorflow.keras.models import load_model
from tensorflow.keras import optimizers
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint
import numpy as np

# One-hot encode the integer labels and hold out 10% of the data for validation
label = to_categorical(data[0], num_classes=class_size)
X_train, X_test, y_train, y_test = train_test_split(train, label, test_size=0.1, random_state=200)
model = Sequential()
# input_dim = num_words + 1 so that every tokenizer index (and padding index 0) fits
model.add(Embedding(num_words + 1, 128, input_length=max_len))
model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
# model.add(Dense(64,activation="relu"))
# model.add(Dropout(0.2))
# model.add(Dense(32,activation="relu"))
# model.add(Dropout(0.2))
model.add(Dense(class_size,activation="softmax"))
# Load a previously saved model instead of training from scratch
# model = load_model('my_model2.h5')
model.compile(optimizer = 'adam', loss='categorical_crossentropy',metrics=['accuracy'])
# Save a checkpoint every 2 epochs (note: `period` is deprecated in newer Keras; use `save_freq` there)
checkpointer = ModelCheckpoint("./model/model_{epoch:03d}.h5", verbose=0, save_best_only=False, save_weights_only=False, period=2)
model.fit(X_train, y_train, validation_data = (X_test, y_test), epochs=epochs, batch_size=batch_size, callbacks=[checkpointer])
# model.fit(X_train, y_train, validation_split = 0.2, shuffle=True, epochs=epochs, batch_size=batch_size, callbacks=[checkpointer])
model.save('my_model4.h5')
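As a quick cross-check of the architecture above (and of what model.summary() would report), the per-layer parameter counts can be worked out by hand:

```python
# Hyperparameters used in this article
num_words, embed_dim, lstm_units, class_size = 5000, 128, 128, 12

# Embedding: one embed_dim-dimensional vector per row of the (num_words + 1)-entry table
embedding_params = (num_words + 1) * embed_dim  # 640128

# LSTM: 4 gates, each with input weights, recurrent weights and a bias vector
lstm_params = 4 * (embed_dim * lstm_units + lstm_units * lstm_units + lstm_units)  # 131584

# Dense softmax head: weight matrix plus biases
dense_params = lstm_units * class_size + class_size  # 1548

print(embedding_params + lstm_params + dense_params)  # 773260
```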
# print(model.summary())

The training process is shown in the figure below:

5. Visualizing the Training Results
import matplotlib.pyplot as plt
# Plot training & validation accuracy
plt.plot(model.history.history['accuracy'])
plt.plot(model.history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
# Plot training & validation loss
plt.plot(model.history.history['loss'])
plt.plot(model.history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
6. Model Prediction
text = "独家新闻 章子怡 机场 暴走 奔 情郎 图 提要 6 月 明星 充实 一边 奉献 爱心 积极参与 赈灾 重建 活动 "
text = [" ".join([word for word in jieba.cut(text) if word not in cachedStopWords])]
# Reuse the tokenizer fitted on the training corpus; fitting a new
# Tokenizer on the test text would produce a different word index
seq = tokenizer.texts_to_sequences(text)
padded = sequence.pad_sequences(seq, maxlen=max_len)
test_pre = model.predict(padded)
test_pre.argmax(axis=1)
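argmax only yields the integer class id; to recover a readable category name, the mapping built during preprocessing can be inverted. A sketch (label_dict here is a made-up stand-in for the mapping built from the real data):

```python
import numpy as np

# Stand-in for the category -> id mapping built during preprocessing
label_dict = {"sports": 0, "entertainment": 1, "finance": 2}
inverse_label_dict = {v: k for k, v in label_dict.items()}

# Stand-in for model.predict output: one softmax row per input text
test_pre = np.array([[0.1, 0.7, 0.2]])
class_ids = test_pre.argmax(axis=1)
print([inverse_label_dict[int(i)] for i in class_ids])  # ['entertainment']
```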
7. Code Download