Need help with GPT2-NewsTitle?
Click the “chat” button below for chat support from the developer who created it, or find similar developers for support.

About the developer

liucongg
203 Stars 37 Forks Apache License 2.0 24 Commits 0 Opened issues

Description

Chinese NewsTitle Generation Project by GPT2.带有超级详细注释的中文GPT2新闻标题生成项目。

Services available

!
?

Need anything else?

Contributors list

No Data

GPT2-NewsTitle

带有超详细注释的GPT2新闻标题生成项目

UpDate 01.02.2021

  • 从网上收集数据,将清华新闻数据、搜狗新闻数据等新闻数据集,以及开源的一些摘要数据进行整理清洗,构建一个较完善的中文摘要数据集。
  • 数据集清洗时,仅进行了简单地规则清洗。例如:清洗htlm标记、去除多余空字符、去除图片标记等。
  • 处理后数据集详细信息,见数据集描述

| 数据 | 原始数据/项目地址 | 处理后文件下载地址 | | ------ | ------ | ------ | | 清华新闻数据 | 地址 | 百度云盘 提取码: vhol | | 搜狗新闻数据 | 地址 | 百度云盘 提取码:ode6 | | nlpcc2017摘要数据 | 地址 | 百度云盘 提取码:e0zq | | csl摘要数据 | 地址 | 百度云盘 提取码:0qot | | 教育培训行业摘要数据 | 地址 | 百度云盘 提取码:kjz3 | | lcsts摘要数据 | 地址 | 百度云盘 提取码:bzov | | 神策杯2018摘要数据 | 地址 | 百度云盘 提取码:6f4f | | 万方摘要数据 | 地址 | 百度云盘 提取码: p69g| | 微信公众号摘要数据 | 地址 | 百度云盘 提取码: 5has| | 微博数据 | 地址 | 百度云盘 提取码: 85t5| | news2016zh新闻数据 | 地址 | 百度云盘 提取码: qsj1 |

数据集集合:百度云盘 提取码: 7am8

项目描述

  • 本项目是一个带有超级详细中文注释的基于GPT2模型的新闻标题生成项目。
  • 本项目参考了GPT2-ChineseGPT2-chitchatCDial-GPTGPT2等多个GPT2开源项目,并根据自己的理解,将代码进行重构,添加详细注释,希望可以帮助到有需要的人。
  • 本项目使用HuggingFace的transformers实现GPT2模型代码编写、训练及测试。
  • 本项目通过Flask框架搭建了一个Web服务,将新闻摘要生成模型进行工程化,可以通过页面可视化地体验新闻标题生成效果。
  • 本项目的代码详细讲解,可以自行阅读代码,也可查看代码注释介绍
  • 本项目提供的新闻标题模型是一个6层的小模型(其实是穷人没人卡,只能训练小模型),并且在训练该模型过程中,没有加载预训练的GPT2模型而是随机初始化的参数,并且训练轮数较少(5轮,还没收敛完),因此效果一般。如果想要更好效果的模型,可以按照个人需求训练一个模型。
  • 本项目的目的是带领大家走一遍GPT2生成模型的训练、测试及部署全部流程。

文件结构

  • config
    • config.json 模型的配置信息,包含nctx、nembd、nhead、nlayer等。
  • vocab
    • vocab.txt 字典文件,该字典为大小为13317,删除了将原始字典中的“##中文”,并且增加了“[Content]”、“[Title]”、“[Space]”等标记。
  • data_dir 存放数据的文件夹
  • templates 存放html页面的文件夹
  • data_helper.py 数据预处理文件,将数据进行简单的清洗
  • data_set.py 数据类文件,定义模型所需的数据类,方便模型训练使用
  • model.py GPT2模型文件,主要对transformers包中GPT2LMHeadModel的重写,修改计算loss部分,只计算预测title部分的loss
  • train.py 通过新闻正文生成新闻标题的GPT2模型的训练文件
  • generate_title.py 根据训练好的模型,进行新闻标题生成,预测文件
  • http_server.py 构建web服务文件

运行环境

  • gevent == 1.3a1
  • flask == 0.12.2
  • transformers == 3.0.2

详细见requirements.txt文件

数据集

数据来源于新浪微博,数据链接:https://www.jianshu.com/p/8f52352f0748?tdsourcetag=spcqqaiomsg

| 数据描述 | 下载地址 | | ------ | ------------- | | 原始数据 | 百度网盘,提取码: nqzi | | 处理后数据 | 百度网盘,提取码: duba |

原始数据为直接从网上下载的新闻数据,处理后数据为使用data_helper.py处理过的数据,可直接用于训练。

模型参数

详细见config/config.json文件

| 参数 | 值 | | ------ | ------------- | | initializerrange | 0.02 | | layernormepsilon | 1e-05 | | nctx | 512 | | nembd | 768 | | nhead | 12 | | nlayer | 6 | | npositions | 512 | | vocab_size | 13317 |

注意:模型输入除了各个词的向量表示外,还包括文字段落向量表示和位置向量表示。

模型文件分享

| 模型 | 下载地址 | | ------ | ------------- | | GPT2模型 | 百度网盘,提取码: 165b |

模型训练

python3 train.py
或
python3 train.py --output_dir output_dir/(自定义保存模型路径) 

训练参数可自行添加,包含参数具体如下:

| 参数 | 类型 | 默认值 | 描述 | | ------ | ------ | ------ | ------ | | device | str | "0" | 设置训练或测试时使用的显卡 | | configpath | str | "config/config.json" | 模型参数配置信息 | | vocabpath | str | "vocab/vocab.txt" | 词表,该词表为小词表,并增加了一些新的标记 | | trainfilepath | str | "datadir/traindata.json" | 新闻标题生成的训练数据 | | testfilepath | str | "datadir/testdata.json" | 新闻标题生成的测试数据 | | pretrainedmodelpath | str | None | 预训练的GPT2模型的路径 | | datadir | str | "datadir" | 生成缓存数据的存放路径 | | numtrainepochs | int | 5 | 模型训练的轮数 | | trainbatchsize | int | 16 | 训练时每个batch的大小 | | testbatchsize | int | 8 | 测试时每个batch的大小 | | learningrate | float | 1e-4 | 模型训练时的学习率 | | warmupproportion | float | 0.1 | warm up概率,即训练总步长的百分之多少,进行warm up操作 | | adamepsilon | float | 1e-8 | Adam优化器的epsilon值 | | loggingsteps | int | 20 | 保存训练日志的步数 | | evalsteps | int | 4000 | 训练时,多少步进行一次测试 | | gradientaccumulationsteps | int | 1 | 梯度积累 | | maxgradnorm | float | 1.0 | | | outputdir | str | "outputdir/" | 模型输出路径 | | seed | int | 2020 | 随机种子 | | maxlen | int | 512 | 输入模型的最大长度,要比config中n_ctx小 |

或者修改train.py文件中的set_args函数内容,可修改默认值。

本项目提供的模型,共训练了5个epoch,模型训练损失和测试集损失分别如下:

模型其实还没有训练完全,按照loss走势,还可以继续训练。

模型测试

python3 generate_title.py
或
python3 generate_title.py --top_k 3 --top_p 0.9999 --generate_max_len 32

参数可自行添加,包含参数具体如下:

| 参数 | 类型 | 默认值 | 描述 | | ------ | ------ | ------ | ------ | | device | str | "0" | 设置训练或测试时使用的显卡 | | modelpath | str | "outputdir/checkpoint-139805" | 模型文件路径 | | vocabpath | str | "vocab/vocab.txt" | 词表,该词表为小词表,并增加了一些新的标记 | | batchsize | int | 3 | 生成标题的个数 | | generatemaxlen | int | 32 | 生成标题的最大长度 | | repetitionpenalty | float | 1.2 | 重复处罚率 | | topk | int | 5 | 解码时保留概率最高的多少个标记 | | topp | float | 0.95 | 解码时保留概率累加大于多少的标记 | | maxlen | int | 512 | 输入模型的最大长度,要比config中n_ctx小 |

测试结果如下: ``` 从测试集中抽一篇 content: 今日,中国三条重要高铁干线——兰新高铁、贵广铁路和南广铁路将开通运营。其中兰新高铁是中国首条高原高铁,全长1776公里,最高票价658元。贵广铁路最贵车票320元,南广铁路最贵车票206.5元,这两条线路大大缩短西南与各地的时空距离。出行更方便了!中国“高铁版图”再扩容 三条重要高铁今日开通 title: 生成的第1个标题为:中国“高铁版图”再扩容 三条重要高铁今日开通 生成的第2个标题为:贵广铁路最高铁版图 生成的第3个标题为:出行更方便了!中国“高铁版图”再扩容三条重要高铁今日开通

从网上随便找一篇新闻 content: 值岁末,一年一度的中央经济工作会议牵动全球目光。今年的会议,背景特殊、节点关键、意义重大。12月16日至18日。北京,京西宾馆。站在“两个一百年”奋斗目标的历史交汇点上,2020年中央经济工作会议谋划着中国经济发展大计。习近平总书记在会上发表了重要讲话,深刻分析国内外经济形势,提出2021年经济工作总体要求和政策取向,部署重点任务,为开局“十四五”、开启全面建设社会主义现代化国家新征程定向领航。 title: 生成的第1个标题为:习近平总书记在京会上发表重大计划 提出2025年经济工作总体要求和政策 生成的第2个标题为:习近平总书记在会上发表重要讲话 生成的第3个标题为:习近平总书记在会上发表重要讲话,深刻分析国内外经济形势 ``` 解码采用topk和topp解码策略,有一定的随机性,可重复生成。

启动Flask服务

python3 http_server.py
或
python3 http_server.py --http_id "0.0.0.0" --port 5555

本地测试直接使用"127.0.0.1:5555/news-title-generate",如果给他人访问,只需将"127.0.0.1"替换成的电脑的IP地址即可。

具体如下图所示:

未来工作

  • [x] 后期可能会将清华新闻数据、搜狗新闻数据等新闻数据集进行整理清洗,构建一个较完善的新闻标题数据集。
  • [ ] 后期可能会使用新闻数据训练一个小的GPT2预训练模型。
  • [ ] 后期可能会对已上传的新闻标题模型进行更新,训练一个效果较好的模型。

致谢

参考

Citing

@misc{GPT2-NewsTitle,
  author = {Cong Liu},
  title = {Chinese NewsTitle Generation Project by GPT2},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  url="https://github.com/liucongg/GPT2-NewsTitle",
}

联系作者

e-mail:[email protected]

知乎:刘聪NLP

公众号:NLP工作站

We use cookies. If you continue to browse the site, you agree to the use of cookies. For more information on our use of cookies please see our Privacy Policy.