couplet

command
v1.0.1
Published: Jun 17, 2023 License: MIT Imports: 10 Imported by: 0

README

Couplet Matching

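This example uses a GPT model to automatically generate the matching second line of a Chinese couplet. The model is trained on the open-source couplet-dataset; sample results are shown below.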

$ go run main.go evaluate --model model1M 投石向天跟命斗
load embedding...
model loaded
inputs: [1256 190 382 11 2671 620 491]
map[1.9197799:[品] 3.3221273:[佐] 3.426636:[闭]]
map[1.3936338:[乡] 2.4676633:[堂] 4.7822757:[门]]
map[4.367066:[邀] 6.651759:[伴] 8.640822:[问]]
map[2.7545571:[角] 3.9995427:[卷] 4.0938125:[海]]
map[0.34838527:[兼] 0.7547303:[溅] 5.4571667:[与]]
map[3.0720437:[初] 3.3736944:[先] 7.8570585:[时]]
map[-0.6402834:[禅] 0.3423911:[留] 4.0661287:[争]]
闭门问海与时争

$ go run main.go evaluate --model model1M 忽忽几晨昏,离别间之,疾病间之,不及终年同静好
load embedding...
model loaded
inputs: [1410 1410 144 807 1132 2 725 431 128 250 2 1620 914 128 250 2 14 1267 408 31 96 313 95]
map[3.2363486:[偶] 3.5864816:[恰] 3.9696732:[茕]]
map[1.8859183:[已] 2.6491146:[正] 5.8938594:[茕]]
map[3.6261964:[大] 4.3181977:[正] 5.94763:[小]]
map[0.8318534:[秀] 1.1576941:[尧] 4.0629478:[儿]]
map[1.1707921:[卧] 2.194219:[诗] 4.736803:[女]]
map[3.2818966:[丝] 8.290359:[?] 12.927086:[,]]
map[-0.19938715:[也] 2.2089927:[香] 7.078493:[孱]]
map[2.0709915:[英] 2.7168777:[盛] 7.2097497:[羸]]
map[1.8610702:[边] 4.1446333:[下] 5.9395804:[若]]
map[1.3027655:[淹] 1.3096942:[公] 4.4640884:[此]]
map[1.6467404:[归] 2.05125:[?] 9.954161:[,]]
map[2.7599795:[登] 2.9466662:[雄] 5.2124147:[娇]]
map[-0.4418694:[歌] 0.5701082:[缤] 5.0582767:[憨]]
map[3.122329:[惊] 3.4507442:[外] 8.34822:[若]]
map[0.24478196:[了] 1.7582227:[可] 5.895115:[此]]
map[1.1890669:[千] 6.6342874:[?] 12.040789:[,]]
map[2.8971107:[堪] 3.3520665:[何] 6.9046826:[更]]
map[0.6985532:[种] 0.8558793:[立] 6.1792207:[烦]]
map[1.037328:[有] 1.7161521:[几] 1.8144897:[二]]
map[0.2053172:[领] 0.6163587:[载] 2.3612778:[老]]
map[3.3555043:[永] 6.1910315:[与] 8.306862:[费]]
map[4.0877986:[谐] 4.6328783:[闲] 7.0630946:[精]]
map[0.90152544:[强] 1.8571142:[坤] 3.6386757:[神]]
茕茕小儿女,孱羸若此,娇憨若此,更烦二老费精神

$ go run main.go evaluate --model model1M 我是谁
load embedding...
model loaded
inputs: [85 62 192]
map[3.805028:[妻] 4.5206566:[谁] 5.46013:[你]]
map[1.9964612:[连] 3.2622404:[当] 4.869252:[为]]
map[0.652869:[相] 1.5072656:[以] 2.6890714:[自]]
你为自

About 1.34 million parameters in total, with a vocabulary of 4,435 characters (trained on only the first 10,000 samples).

+------------------------+---------+
|          NAME          |  COUNT  |
+------------------------+---------+
| transformer0_attention |   62208 |
| transformer0_dense     |   66048 |
| transformer0_output    |   65664 |
| transformer1_attention |   62208 |
| transformer1_dense     |   66048 |
| transformer1_output    |   65664 |
| transformer2_attention |   62208 |
| transformer2_dense     |   66048 |
| transformer2_output    |   65664 |
| transformer3_attention |   62208 |
| transformer3_dense     |   66048 |
| transformer3_output    |   65664 |
| output                 |  572244 |
| total                  | 1347924 |
+------------------------+---------+
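As a sanity check, the totals in the table can be reproduced with a few lines of arithmetic. The sketch below takes the attention and output-layer counts from the table as-is. The `dense` and `output` counts happen to match a feed-forward layer with a hidden width of 512 plus biases; that width (`ffnDim`) is an assumption inferred from the numbers, not something the README states.

```go
package main

import "fmt"

const (
	embeddingDim    = 128 // from the model configuration
	transformerSize = 4   // from the model configuration
	ffnDim          = 512 // assumption: hidden width of the feed-forward layer
)

// blockParams returns the parameter count of one transformer block.
func blockParams() int {
	attention := 62208 // taken from the table as-is
	// Weights plus biases for the two feed-forward layers.
	dense := embeddingDim*ffnDim + ffnDim        // 66048
	output := ffnDim*embeddingDim + embeddingDim // 65664
	return attention + dense + output
}

// totalParams adds the final output layer to all transformer blocks.
func totalParams() int {
	const outputLayer = 572244 // taken from the table as-is
	return transformerSize*blockParams() + outputLayer
}

func main() {
	fmt.Println(blockParams()) // 193920
	fmt.Println(totalParams()) // 1347924
}
```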

train 200, cost=56m41.294214246s, loss=0.646975

Model Parameter Configuration

const embeddingDim = 128 // each character vector is 128 float32 values
const paddingSize = 34   // maximum sequence length is 34
const heads = 8
const batchSize = 128
const epoch = 200
const lr = 0.001
const transformerSize = 4
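Two quantities follow from these constants and may help when tuning them. The sketch below assumes the usual multi-head split, where the embedding width is divided evenly across heads, and assumes one training sample per couplet; neither is confirmed by this README. The helper names `headDim` and `batchesPerEpoch` are hypothetical.

```go
package main

import "fmt"

// headDim returns the per-head width, assuming the embedding is split evenly
// across attention heads (the common convention, not confirmed here).
func headDim(embeddingDim, heads int) (int, error) {
	if heads <= 0 || embeddingDim%heads != 0 {
		return 0, fmt.Errorf("embeddingDim %d is not divisible by heads %d", embeddingDim, heads)
	}
	return embeddingDim / heads, nil
}

// batchesPerEpoch rounds up so that a final partial batch is counted.
func batchesPerEpoch(samples, batchSize int) int {
	return (samples + batchSize - 1) / batchSize
}

func main() {
	d, err := headDim(128, 8)
	if err != nil {
		panic(err)
	}
	fmt.Println(d)                           // 16
	fmt.Println(batchesPerEpoch(10000, 128)) // 79
}
```

Under these assumptions, changing `heads` to a value that does not divide `embeddingDim` would be an error worth catching before a long training run.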

Build

After adjusting the parameters in logic/model/params.go, build with the following command:

go build

Training

# Download the dataset
./couplet download
# Trim the dataset to speed up training
./couplet cut 10000
# Train the model
./couplet train

Inference

$ ./couplet evaluate --model ./model1M 丹枫江冷人初去
load embedding...
model loaded
inputs: [338 758 51 394 6 539 155]
map[3.2917404:[青] 3.3120193:[刚] 3.6883862:[绿]]
map[1.7707646:[桃] 2.7732375:[致] 6.6006026:[柳]]
map[3.0857203:[琐] 3.4510996:[岭] 6.9275184:[堤]]
map[2.6540723:[红] 4.5055385:[寒] 6.616006:[新]]
map[4.680617:[眼] 5.455213:[燕] 6.5080013:[月]]
map[1.6231927:[好] 2.3431666:[不] 7.3107166:[复]]
map[2.0388474:[欢] 4.897517:[归] 11.395214:[来]]
绿柳堤新月复来

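Because the GPT model generates one character at a time, each line of the output corresponds to one position and lists the top three candidate characters with their scores. The highest-scoring candidate at each position makes up the final line.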
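The selection step can be sketched as a greedy pick over each position's candidates. The `candidate` type and the `pickBest` and `decode` functions are hypothetical names for illustration; the scores are copied from the inference run above, and the project's own decoding code may differ.

```go
package main

import (
	"fmt"
	"strings"
)

// candidate pairs a character with the score printed next to it.
type candidate struct {
	char  string
	score float64
}

// sample holds the seven positions from the inference run above.
var sample = [][]candidate{
	{{"青", 3.2917404}, {"刚", 3.3120193}, {"绿", 3.6883862}},
	{{"桃", 1.7707646}, {"致", 2.7732375}, {"柳", 6.6006026}},
	{{"琐", 3.0857203}, {"岭", 3.4510996}, {"堤", 6.9275184}},
	{{"红", 2.6540723}, {"寒", 4.5055385}, {"新", 6.616006}},
	{{"眼", 4.680617}, {"燕", 5.455213}, {"月", 6.5080013}},
	{{"好", 1.6231927}, {"不", 2.3431666}, {"复", 7.3107166}},
	{{"欢", 2.0388474}, {"归", 4.897517}, {"来", 11.395214}},
}

// pickBest returns the highest-scoring character, or "" for an empty list.
func pickBest(cands []candidate) string {
	if len(cands) == 0 {
		return ""
	}
	best := cands[0]
	for _, c := range cands[1:] {
		if c.score > best.score {
			best = c
		}
	}
	return best.char
}

// decode greedily picks one character per position and joins them.
func decode(positions [][]candidate) string {
	var sb strings.Builder
	for _, cands := range positions {
		sb.WriteString(pickBest(cands))
	}
	return sb.String()
}

func main() {
	fmt.Println(decode(sample)) // 绿柳堤新月复来
}
```

The candidates appear in ascending score order in the logs because Go's `fmt` package prints map keys in sorted order, so the chosen character is the last one on each line.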

Directories

Path Synopsis
logic
