首页 > 其他分享 >基于keras采用LSTM实现多标签文本分类

基于keras采用LSTM实现多标签文本分类

时间:2023-03-17 16:34:28浏览次数:32  
标签:keras 标签 label test train print import LSTM data

首先抓取博客园知识库(kb.cnblogs.com)的文章标题和分类,每条记录以“[分类] 标题”的格式逐行追加写入 data/bokeyuan_fenlei.csv。

代码:

#coding=utf-8

import os
import sys
import requests
from lxml import etree,html
import lxml
import time
import re

filepath = 'data/bokeyuan_fenlei.csv'


def zhuaqudata():
    """Crawl every list page of kb.cnblogs.com and append the records to ``filepath``.

    Starts at page 1 and keeps requesting the next page for as long as
    ``getwenzhangandnext`` reports that a "next" link exists. Every record is
    printed and appended to the module-level ``filepath`` as a
    ``"[category] title"`` line.
    """
    page = 1
    haslast = 1  # always fetch the first page
    # One loop handles page 1 and all following pages; the original repeated
    # the whole fetch/print/write sequence once before the loop.
    while haslast:
        print("开始抓取%s页..." % page)
        haslast, titles, fenleis = getwenzhangandnext(page)
        for title, fenlei in zip(titles, fenleis):
            print('[%s] %s' % (fenlei, title))
            writefile(filepath, "[%s] %s\n" % (fenlei, title))
        print()
        page += 1
        
def getwenzhangandnext(page):
    """Fetch one kb.cnblogs.com list page and extract its article titles and categories.

    Args:
        page: 1-based page number. Page 1 is the site root; page N is
            ``https://kb.cnblogs.com/N/``.

    Returns:
        A tuple ``(haslast, titles, fenleis)``:
        ``haslast`` is 1 when the text of the last pager link contains "next"
        (case-insensitive), otherwise 0; ``titles`` and ``fenleis`` are
        parallel lists of title and category strings cleaned by ``formatstr``.

    Sleeps 3 seconds before returning to throttle requests to the site.
    """
    baseurl = 'https://kb.cnblogs.com/'
    url = baseurl if page == 1 else baseurl + str(page) + '/'
    print(url)
    content = geturl(url)
    htmlcontent = etree.HTML(content)

    titles = []
    fenleis = []
    haslast = 0

    # etree.HTML returns None for an empty document; treat that as an empty last page.
    if htmlcontent is not None:
        ps = htmlcontent.xpath('//div[@class="list_block"]//div[@class="msg_title"]//p')
        for p in ps:
            # Query relative to the <p> itself instead of serializing and
            # re-parsing it. Entries lacking the category <span> text or the
            # title attribute are skipped instead of raising IndexError.
            fenlei_texts = p.xpath('.//span//text()')
            title_attrs = p.xpath('.//a//@title')
            if not fenlei_texts or not title_attrs:
                continue
            titles.append(formatstr(title_attrs[0]))
            fenleis.append(formatstr(fenlei_texts[0]))

        pager_texts = htmlcontent.xpath('//div[@id="pager_block"]//div[@id="pager"]//a[last()]//text()')
        # No pager at all means there is no further page, not an error.
        if pager_texts and 'next' in str(pager_texts[0]).lower():
            haslast = 1

    time.sleep(3)
    return haslast, titles, fenleis

def formatstr(str):
    """Reduce a string to letters, digits, CJK ideographs and a few punctuation marks.

    Characters kept: ASCII ``0-9``, ``a-z``, ``A-Z``, CJK Unified Ideographs
    U+4E00 to U+9FA5, and ``:、?!,``. Everything else (spaces, brackets,
    slashes, other punctuation) is dropped. The parameter name shadows the
    builtin ``str`` but is kept so existing keyword callers keep working.
    """
    allowed = re.compile('[0-9a-zA-Z\u4e00-\u9fa5:、?!,]')
    return ''.join(match.group(0) for match in allowed.finditer(str))

def readfile(filepath):
    """Return the whole contents of a UTF-8 text file.

    Args:
        filepath: Path of the file to read.

    Returns:
        The decoded file contents as one string.
    """
    # `with` closes the file even if reading or decoding raises.
    with open(filepath, 'r', encoding='utf-8') as fp:
        return fp.read()

def writefile(filepath, s):
    """Append ``s`` to a UTF-8 text file, creating the file if it does not exist.

    No newline is added; callers supply their own line terminators.

    Args:
        filepath: Path of the file to append to.
        s: Text to append.
    """
    # `with` closes the file even if the write raises.
    with open(filepath, 'a', encoding='utf-8') as fp:
        fp.write(s)

def geturl(url, timeout=30):
    """Download ``url`` with HTTP GET and return the response body as text.

    A desktop Firefox ``User-Agent`` header is sent, and the body is decoded
    using the encoding ``requests`` detects from the content
    (``apparent_encoding``) rather than the charset declared in the headers.

    Args:
        url: Page to download.
        timeout: Seconds to wait when connecting/reading (new optional
            parameter). Without a timeout a stalled connection can block the
            crawler indefinitely.

    Returns:
        The decoded response body.

    Raises:
        requests.HTTPError: on a 4xx/5xx response, so an error page is not
            silently parsed as if it were an article list.
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    header = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:95.0) Gecko/20100101 Firefox/95.0'
    }
    res = requests.get(url, headers=header, timeout=timeout)
    res.raise_for_status()
    res.encoding = res.apparent_encoding
    return res.text

if __name__ == "__main__":
    # Run the crawler only when executed as a script, never on import.
    zhuaqudata()

 

结果:

 

 

然后用下面的程序读取该文件,建立标题(数据)与分类(标签)的对应关系,对文本和标签进行编码,再建模、训练和测试。

代码:

#coding=utf-8

import os
import sys
import re
import jieba
from sklearn.preprocessing import MultiLabelBinarizer
from keras.preprocessing.text import Tokenizer
from keras_preprocessing.sequence import pad_sequences
from keras.models import Sequential,Model,load_model
import numpy as np
from keras.layers import Dense, Input, Flatten, Dropout, LSTM
from keras.layers import Conv1D, MaxPooling1D, Embedding, GlobalMaxPooling1D, SpatialDropout1D
import random

filepath = 'data/bokeyuan_fenlei.csv'
stopwordfilepath = 'data/cn_stopwords.txt'

def readfile(filepath):
    """Return the whole contents of a UTF-8 text file.

    Args:
        filepath: Path of the file to read.

    Returns:
        The decoded file contents as one string.
    """
    # `with` closes the file even if reading or decoding raises.
    with open(filepath, 'r', encoding='utf-8') as fp:
        return fp.read()

def writefile(filepath, s):
    """Append ``s`` to a UTF-8 text file, creating the file if it does not exist.

    No newline is added; callers supply their own line terminators.

    Args:
        filepath: Path of the file to append to.
        s: Text to append.
    """
    # `with` closes the file even if the write raises.
    with open(filepath, 'a', encoding='utf-8') as fp:
        fp.write(s)
    
def duqushuju():
    """Load the crawled records and build the data/label lists plus a train/test split.

    Reads ``filepath`` (records of the form ``[category] title``, one per
    line) and ``stopwordfilepath`` (one stopword per line). Each title is
    cleaned, segmented and stopword-filtered by ``contentsplit``. Each
    category is wrapped in a one-element list, the per-sample label-set form
    that ``MultiLabelBinarizer`` consumes.

    Returns:
        ``(all_data, all_fenlei, train_data, train_label, test_data, test_label)``.
        The ``*_data`` lists hold space-separated token strings; the label
        lists hold one-element category lists. While ``trainlen`` is 0, the
        train and test lists are both copies of the full data set, so nothing
        is held out for testing.
    """
    text = readfile(filepath)
    # The stopword file is optional: a missing file means "no stopwords".
    stop_text = readfile(stopwordfilepath) if os.path.exists(stopwordfilepath) else ''
    # A set gives O(1) membership tests inside contentsplit. Entries are
    # stripped so stray spaces/CR in the file do not prevent a match.
    stopwords = {w.strip() for w in stop_text.split('\n') if w.strip()}
    # Raw string: '\[' in a plain literal is an invalid escape (SyntaxWarning
    # on Python 3.12+). '(?:\n|$)' also accepts a final record without a
    # trailing newline.
    records = re.findall(r'\[(.*?)\](.*?)(?:\n|$)', text)

    titles = []
    fenleis = []
    # For a random split, shuffle `records` (random.shuffle) before this loop.
    for fenlei, title in records:
        fenleis.append([fenlei])
        titles.append(contentsplit(title, stopwords))

    # Number of leading records used for training; 0 disables the split.
    # Use e.g. int(len(fenleis) * 0.8) for an 80/20 split.
    trainlen = 0

    if trainlen > 0:
        train_data = titles[:trainlen]
        train_label = fenleis[:trainlen]
        test_data = titles[trainlen:]
        test_label = fenleis[trainlen:]
    else:
        train_data = titles[:]
        train_label = fenleis[:]
        test_data = titles[:]
        test_label = fenleis[:]

    return titles, fenleis, train_data, train_label, test_data, test_label

def contentsplit(segment, stopwords):
    """Clean a title, segment it with jieba and keep only informative tokens.

    Args:
        segment: Raw title text; it is first reduced by ``formatstr``.
        stopwords: Container of words to discard (only ``in`` is used).

    Returns:
        The surviving tokens joined by single spaces. This is the string form
        the training script feeds to the keras ``Tokenizer``.
    """
    kept = []
    for word in jieba.cut(formatstr(segment)):
        stripped = word.strip()
        # Drop blank tokens, stopwords and single-character tokens.
        if not stripped or stripped in stopwords or len(word) <= 1:
            continue
        kept.append(word)
    return " ".join(kept)

def formatstr(str):
    """Keep only ASCII letters/digits and CJK Unified Ideographs (U+4E00 to U+9FA5).

    Every other character (whitespace, punctuation, symbols) is removed. The
    parameter name shadows the builtin ``str`` but is kept for compatibility
    with keyword callers.
    """
    pattern = re.compile('[0-9a-zA-Z\u4e00-\u9fa5]')
    pieces = [m.group() for m in pattern.finditer(str)]
    return ''.join(pieces)
    
if __name__ == '__main__':
    # Load titles/labels and the train/test split produced by duqushuju().
    all_data,all_fenlei,train_data,train_label,test_data,test_label = duqushuju()
    print('总分类大小:%s' % len(all_fenlei))
    print('总标题大小:%s' % len(all_data))
    print('训练分类大小:%s' % len(train_label))
    print('训练标题大小:%s' % len(train_data))
    print('测试分类大小:%s' % len(test_label))
    print('测试标题大小:%s' % len(test_data))

    # Label vectorization. Fit the binarizer ONCE on the training labels and
    # reuse it for the test labels: a second binarizer fit on the test labels
    # may order (or size) its columns differently, so the test targets would
    # not line up with the model's output units. Test labels never seen in
    # training are ignored by transform() (scikit-learn emits a warning).
    mutil_lab = MultiLabelBinarizer()
    train_label_code = mutil_lab.fit_transform(train_label)
    test_label_code = mutil_lab.transform(test_label)
    # Column j of a label vector / model output corresponds to classes[j].
    classes = mutil_lab.classes_

    num_words = 40000  # vocabulary size shared by the Tokenizer and the Embedding layer
    maxlen = 100       # every title is padded/truncated to this many tokens

    # Raw string: same filter characters as before (including the backslash),
    # without the invalid '\]' escape sequence.
    tokenizer = Tokenizer(num_words=num_words, filters=r'!"#$%&()*+,-./:;<=>?@[\]^_`{|}~', lower=True)
    tokenizer.fit_on_texts(train_data)

    # Turn the space-separated token strings into padded integer sequences.
    x_data = pad_sequences(tokenizer.texts_to_sequences(train_data), maxlen)
    y_data = np.array(train_label_code)
    x_test_data = pad_sequences(tokenizer.texts_to_sequences(test_data), maxlen)
    y_test_data = np.array(test_label_code)

    print("训练集的大小为: ", x_data.shape, "训练集标签的大小为: ", y_data.shape)
    print("测试集的大小为: ", x_test_data.shape, "测试集标签的大小为: ", y_test_data.shape)

    model_path = 'models/wenben_fenlei_lstm.h5'
    if os.path.exists(model_path):
        # NOTE: a saved model is reused as-is. If the crawled label set has
        # changed since it was trained, delete the file so the model is
        # rebuilt with the right number of output units.
        model = load_model(model_path)
    else:
        # Build the model.
        inputs = Input(shape=(maxlen,))
        embed = Embedding(num_words, 100, input_length=x_data.shape[1])(inputs)
        dropout = SpatialDropout1D(0.2)(embed)

        # These LSTM arguments (tanh/sigmoid activations, recurrent_dropout=0)
        # are chosen so Keras can use the cuDNN-accelerated kernel.
        lstm = LSTM(100, dropout=0.2, recurrent_dropout=0, activation='tanh', recurrent_activation='sigmoid')(dropout)
        # One independent sigmoid per label with binary cross-entropy: a multi-label setup.
        output = Dense(y_data.shape[1], activation='sigmoid')(lstm)
        model = Model(inputs, output)
        model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        model.summary()  # prints the architecture; this does not evaluate the model

        model.fit(x_data, y_data, batch_size=16, epochs=20, validation_data=(x_test_data, y_test_data))

        # save() fails if 'models/' is missing, which would waste the training run.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        model.save(model_path)

    # Sanity check on the first few training titles. argmax() gives a CLASS
    # index, so it is mapped through `classes`, not through a per-sample lookup.
    n = min(3, len(x_data))
    pre = model.predict(x_data[:n], batch_size=n)
    for i in range(n):
        print('[%s] %s' % (','.join(train_label[i]), train_data[i]))
        print('预测值为:%s'  % classes[pre[i].argmax()])
        print()

    # Predict a few unseen titles. They must go through the same preprocessing
    # as the training titles (formatstr + jieba + stopword filter, via
    # contentsplit); the Tokenizer was fit on space-separated jieba tokens, so
    # raw unsegmented titles would mostly become out-of-vocabulary phrases.
    if os.path.exists(stopwordfilepath):
        stopwords = {w.strip() for w in readfile(stopwordfilepath).split('\n') if w.strip()}
    else:
        stopwords = set()
    ceshi_data = ['FWT/快速沃尔什变换 入门指南', '如何在 Apinto 实现 HTTP 与gRPC 的协议转换 (下)', '万字血书Vue—Vue语法', '云图说丨初识华为云安全云脑——新一代云安全运营中心']
    ceshi_tokens = [contentsplit(title, stopwords) for title in ceshi_data]
    x_ceshi_data = pad_sequences(tokenizer.texts_to_sequences(ceshi_tokens), maxlen)

    pre = model.predict(x_ceshi_data, batch_size=len(ceshi_data))
    for title, probs in zip(ceshi_data, pre):
        print('%s' % title)
        print('预测值为:%s'  % classes[probs.argmax()])
        print()

 

停用词文件 data/cn_stopwords.txt 可以自己随便创建一个,内容为空也没有问题。它只决定分词之后哪些词会被过滤掉,并不影响 jieba 切词本身,但会影响输入模型的词,从而影响分类效果。

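我先对训练库的前三个标题做了预测,结果基本正确;然后对 4 个博客文章的标题做了预测,至少能给出结果。需要注意的是,代码中 trainlen = 0,测试集与训练集完全相同,因此训练时显示的验证准确率不能反映模型在新数据上的效果;如需真实评估,可先打乱数据,再把 trainlen 设为 int(len(fenleis) * 0.8)。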

效果:

 

参考:https://blog.csdn.net/qq_56154355/article/details/125685955

标签:keras,标签,label,test,train,print,import,LSTM,data
From: https://www.cnblogs.com/xuxiaobo/p/17227240.html

相关文章

  • 【python爬虫】bs4遍历、搜索文档树 bs4使用css选择器 selenium基本使用 selenium查
    目录上节回顾今日内容0bs4遍历文档树1bs4搜索文档树1.1find方法的其他参数2css选择器3selenium基本使用4无界面浏览器4.1模拟登录百度5selenium其它用法5.0查找标......
  • p标签单行文本溢出和多行文本溢出显示省略号解决方法
    单行p{ overflow:hidden; text-overflow:ellipsis; white-space:nowrap;}多行//多行显示省略号,数字3为超出3行显示,p{ display:-webkit-box; -webkit-box......
  • tensorflow.keras.datasets 中关于imdb.load_data的使用说明
    python深度学习在加载数据时(num_words=10000)所代表的意义首先写一段深度学习加载数据集的代码:fromkeras.datasetsimportreuters(train_data,train_labels),(test_dat......
  • HTML5新增标签
    HTML5新增标签 HTML5是HTML最新的修订版本,2014年10月由万维网联盟(w3c)完成标准制定在HTML5出现之前,我们一般采用DIV+CSS布局我们的页面.但是这样的布局方式不仅使......
  • Day02 2.2、HTML基础之表单标签
    二、表单标签是HTML中最终的标签之一,主要是提供了输入框或按钮等标签提供给用户进行交互输入数据。将来表单可以提交到指定服务端程序中进行数据处理。1form标签......
  • Day02 2.3、HTML基础之表单标签的基本使用
    三、表格标签表格系列标签主要是可以数据以表格的格式展示出来。但是现在table表格已经很少使用了,而是改成div+css实现更漂亮的表格。标签描述<table></table......
  • Day02 2.3、HTML基础之标签的练习案例
    使用table+表单,把课堂上的form标签的代码,整理成以下格式(不要外观):<!DOCTYPEhtml><htmllang="en"><head><metacharset="UTF-8"><title>Title</title></h......
  • Day02 2.1、HTML基础之列表标签
    一、列表标签列表是一种结构标签可以让网页的内容形成列表格式。列表标签在HTML中提供提供了4种:无序列表(UnorderList,ul)就是没有序号的,内容不分先后......
  • img标签如何添加动态src地址
    把图片当成模块用require引入图片地址(不带图片名称),后面加上循环遍历的图片名称拼接后就可以展示图片。即 require('@/assets/images/home/'+item.url)  // item.......
  • 6_JSTL格式化标签
    ​ JSTL格式化标签格式化标签库格式化标签库,也叫作fmt标签,是JTSL中的第二大组成部分,主要解决数据显示格式问题,让JSP页面的数据格式更加规范格式化标签库导入的语......