Last year my friend, lecturer 大竹, gave a talk at COSCUP 2016 about doing image recognition on a NAS with TensorFlow.
Thanks to 大竹's explanation, 安迪兒 got started with TensorFlow too.
(Mysterious voice: thanks to 大竹 for taking 安迪兒 to the speakers' dinner for a big feast~XD)
Ever since then, 安迪兒 had been wanting to try TensorFlow for things like text and semantic analysis.
Then recently 安迪兒 noticed the wildly popular AlphaGo, which is basically Sai from Hikaru no Go~XD
So, for fun, 安迪兒 put together a chatbot (to amuse the boss and colleagues).
Below is a record of how this bot was built.
[Related articles]
大竹's COSCUP 2016 - NAS也會揀土豆
URL: http://kaichu.io/2016/08/22/retrain-inception-model-for-nas/
TensorFlow Sequence-to-Sequence Models
URL: https://www.tensorflow.org/tutorials/seq2seq/
TensorFlowのSeq2Seqモデルでチャットボットっぽいものを作ってみた (building something chatbot-like with TensorFlow's Seq2Seq model)
URL: http://qiita.com/San_/items/128bf1b5a898ad5c18f1
About Sequence-to-Sequence Models
The sequence-to-sequence model is an RNN-based model that has become popular in recent years.
It is widely used in fields such as machine translation and automatic question answering,
with quite good results.
If you are interested, see the links in the related articles above.
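The gist, in a nutshell: an encoder RNN folds the input sentence into a state vector, and a decoder RNN generates the reply from that state one token at a time until it emits EOS. Here is a tiny sketch of the idea only, not the tutorial's actual model; encoder_cell and decoder_cell are hypothetical stand-ins for real RNN cells:

# Illustrative sketch of the seq2seq idea only -- not the tutorial's model.
# encoder_cell / decoder_cell are hypothetical stand-ins for real RNN cells.
def seq2seq_sketch(encoder_cell, decoder_cell, input_ids, go_id=1, eos_id=2, max_len=20):
    state = None
    for token in input_ids:            # encode: fold the whole input into one state
        state = encoder_cell(token, state)
    outputs, token = [], go_id         # decode: start from GO, stop at EOS
    while len(outputs) < max_len:
        token, state = decoder_cell(token, state)
        if token == eos_id:
            break
        outputs.append(token)
    return outputs                     # the predicted reply, as token ids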
Prerequisites
- TensorFlow installed: Installing TensorFlow on Mac (see the one-liner after this list)
- Slack bot key application: to be filled in some other day XD
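For reference, on a Mac at the time this boiled down to roughly the following (CPU-only build; the exact command may differ by version, so check the official install guide):

# Install TensorFlow into the active Python environment (CPU-only build).
pip install tensorflow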
Building the chatbot
安迪兒 referenced the Japanese author San's article on building a conversational AI
(linked in the related articles above).
With TensorFlow, the most important thing is a data source that fits your purpose,
but collecting Chinese conversation data is a bit of a hassle.
Feed it the wrong data and, never mind answering off-topic, the bot will be trained crooked.
Where does the data come from?
安迪兒 searched around for movies, subtitles, novels, and so on,
and finally found this
(tearfully thanking the kind soul for sharing it as open source):
>dgk_lost_conv
A Chinese conversation corpus
that can be used as training data for chatbots.
https://github.com/rustch3n/dgk_lost_conv
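For a sense of what is inside: if memory serves, the corpus uses a simple line format where E separates conversations and M prefixes each utterance, with the characters slash-separated (the lines below are illustrative; see the repo README for the authoritative description):

E
M 畹/华/吾/侄
M 你/接/到/这/封/信/的/时/候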
Preparing the data
Following the content of San's article:
>http://qiita.com/San_/items/128bf1b5a898ad5c18f1
https://github.com/sanshirookazaki/chat
Once you have the data, it needs some preparation before training:
- First convert the raw dgk_lost_conv to UTF-8 and from Simplified to Traditional Chinese
- Split it into two word files, questions and answers
- Split out question/answer test files for evaluating the model
- Build the encode/decode vocabulary tables
- Then turn the text in all the files above into vectors of token ids for TensorFlow to train on (see the sketch after this list)
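The last two steps boil down to something like this minimal sketch (Python 3 style; the file names match the training script below, but the character-level tokenization and helper names are assumptions for illustration, not San's exact code):

# Minimal sketch: build a 5000-word vocabulary and turn text lines into token ids.
# Assumptions: character-level tokens; file names match the training script below.
from collections import Counter

SPECIAL = ['__PAD__', '__GO__', '__EOS__', '__UNK__']  # ids 0-3, matching PAD/GO/EOS/UNK below
UNK_ID = 3

def build_vocabulary(text_path, vocab_path, max_size=5000):
    counter = Counter()
    with open(text_path, encoding='utf-8') as f:
        for line in f:
            counter.update(line.strip())               # count each character
    words = SPECIAL + [w for w, _ in counter.most_common(max_size - len(SPECIAL))]
    with open(vocab_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(words))
    return {w: i for i, w in enumerate(words)}         # word -> id

def text_to_vec(text_path, vocab, vec_path):
    with open(text_path, encoding='utf-8') as f, open(vec_path, 'w') as out:
        for line in f:
            ids = [vocab.get(ch, UNK_ID) for ch in line.strip()]
            out.write(' '.join(str(i) for i in ids) + '\n')

vocab = build_vocabulary('train_encode', 'train_encode_vocabulary')
text_to_vec('train_encode', vocab, 'train_encode.vec')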
Time to train the bot
Run: python 3.translate.py
import tensorflow as tf
from tensorflow.models.rnn.translate import seq2seq_model
import os
import numpy as np
import math

# Ids of the special tokens at the start of the vocabulary files.
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3

# Vectorized (token-id) training and test data from the preparation step.
train_encode_vec = 'train_encode.vec'
train_decode_vec = 'train_decode.vec'
test_encode_vec = 'test_encode.vec'
test_decode_vec = 'test_decode.vec'

vocabulary_encode_size = 5000
vocabulary_decode_size = 5000

# (question length, answer length) buckets; each pair goes into the
# smallest bucket that fits it.
buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
layer_size = 256
num_layers = 3
batch_size = 64


def read_data(source_path, target_path, max_size=None):
    """Read question/answer id files and sort each pair into a bucket."""
    data_set = [[] for _ in buckets]
    with tf.gfile.GFile(source_path, mode="r") as source_file:
        with tf.gfile.GFile(target_path, mode="r") as target_file:
            source, target = source_file.readline(), target_file.readline()
            counter = 0
            while source and target and (not max_size or counter < max_size):
                counter += 1
                source_ids = [int(x) for x in source.split()]
                target_ids = [int(x) for x in target.split()]
                target_ids.append(EOS_ID)
                for bucket_id, (source_size, target_size) in enumerate(buckets):
                    if len(source_ids) < source_size and len(target_ids) < target_size:
                        data_set[bucket_id].append([source_ids, target_ids])
                        break
                source, target = source_file.readline(), target_file.readline()
    return data_set


model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,
                                   buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,
                                   batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.97, forward_only=False)

config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'

with tf.Session(config=config) as sess:
    # Resume from the latest checkpoint if there is one, otherwise start fresh.
    ckpt = tf.train.get_checkpoint_state('.')
    if ckpt is not None:
        print(ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    train_set = read_data(train_encode_vec, train_decode_vec)
    test_set = read_data(test_encode_vec, test_decode_vec)

    # Cumulative bucket weights, used to sample buckets in proportion to size.
    train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]
    train_total_size = float(sum(train_bucket_sizes))
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]

    loss = 0.0
    total_step = 0
    previous_losses = []
    while True:
        # Pick a random bucket, weighted by how much data it holds.
        random_number_01 = np.random.random_sample()
        bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])

        encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)
        _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)

        loss += step_loss / 500
        total_step += 1

        print(total_step)
        if total_step % 5000 == 0:
            print(model.global_step.eval(), model.learning_rate.eval(), loss)

            # Decay the learning rate if the loss has stopped improving.
            if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                sess.run(model.learning_rate_decay_op)
            previous_losses.append(loss)

            # Save a checkpoint, then report perplexity on the test set.
            checkpoint_path = "seq2seq.ckpt"
            model.saver.save(sess, checkpoint_path, global_step=model.global_step)
            loss = 0.0

            for bucket_id in range(len(buckets)):
                if len(test_set[bucket_id]) == 0:
                    continue
                encoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)
                _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
                eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
                print(bucket_id, eval_ppx)
This runs for a very long time.
安迪兒 ran it on a Mac, CPU only,
for a week without shutting down, pegged at 100% every day (a bit scary XD),
and it only got through 300,000-odd steps.
Without a graphics card it is painfully slow.
If you cannot finish, you can stop and it will pick up where it left off next time.
The program saves a checkpoint periodically as it runs;
the checkpoint file records how far it got.
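For example, the checkpoint file in the working directory looks something like this (the step number is made up for illustration):

model_checkpoint_path: "seq2seq.ckpt-330000"
all_model_checkpoint_paths: "seq2seq.ckpt-330000"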
Running the bot
Run 安迪兒's 4.bot01.py:
export BOT_ID="your slack bot's id"
export SLACK_TOKEN="your slack bot's api key"
python 4.bot01.py &
The code:
import tensorflow as tf
from tensorflow.models.rnn.translate import seq2seq_model
import os
import numpy as np
import time
from slackclient import SlackClient

# Slack credentials are read from the environment (see the export commands above).
BOT_ID = os.environ["BOT_ID"]
AT_BOT = "<@" + BOT_ID + ">"
EXAMPLE_COMMAND = "orange"
slack_token = os.environ["SLACK_TOKEN"]
slack_client = SlackClient(slack_token)

# Ids of the special tokens, same as in training.
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3

train_encode_vocabulary = 'train_encode_vocabulary'
train_decode_vocabulary = 'train_decode_vocabulary'


def handle_command(command, channel):
    response = str(command)
    if command.startswith(EXAMPLE_COMMAND):
        response = "Sure...write some more code then I can do that!"
    slack_client.api_call("chat.postMessage", channel=channel, text=response, as_user=True)


def parse_slack_output(slack_rtm_output):
    """Return (text, channel) for messages that @-mention the bot."""
    output_list = slack_rtm_output
    if output_list and len(output_list) > 0:
        for output in output_list:
            if output and 'text' in output and AT_BOT in output['text']:
                return output['text'].split(AT_BOT)[1].strip().lower(), output['channel']
    return None, None


def read_vocabulary(input_file):
    """Load a vocabulary file; return (word -> id) and (id -> word) mappings."""
    tmp_vocab = []
    with open(input_file, "r") as f:
        tmp_vocab.extend(f.readlines())
    tmp_vocab = [line.strip() for line in tmp_vocab]
    vocab = dict([(x, y) for (y, x) in enumerate(tmp_vocab)])
    return vocab, tmp_vocab


vocab_en, _ = read_vocabulary(train_encode_vocabulary)
_, vocab_de = read_vocabulary(train_decode_vocabulary)

vocabulary_encode_size = 5000
vocabulary_decode_size = 5000

buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
layer_size = 256
num_layers = 3
batch_size = 1

# Same model shape as in training, but forward_only=True: inference only, no updates.
model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,
                                   buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,
                                   batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.99, forward_only=True)
model.batch_size = 1

READ_WEBSOCKET_DELAY = 1

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state('.')
    if ckpt is not None:
        print(ckpt.model_checkpoint_path)
        model.saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        print("no checkpoint model found")

    if slack_client.rtm_connect():
        print("StarterBot connected and running!")
        while True:
            command, channel = parse_slack_output(slack_client.rtm_read())
            if command and channel:
                print(type(command))
                # Turn the incoming message into a vector of token ids.
                # Look up one character at a time, encoded to UTF-8 to match the
                # byte-string keys in the vocabulary (iterating over raw UTF-8
                # bytes would make every Chinese character come out as UNK).
                input_string_vec = []
                for ch in command.strip():
                    input_string_vec.append(vocab_en.get(ch.encode("utf8"), UNK_ID))
                # Pick the smallest bucket the input fits into and run one decode step.
                bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])
                encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
                # Greedy decoding: take the highest-scoring token at each position,
                # then cut the output off at the first EOS.
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                if EOS_ID in outputs:
                    outputs = outputs[:outputs.index(EOS_ID)]
                response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])
                print("response:" + response)
                slack_client.api_call("chat.postMessage", channel=channel, text=unicode(response, "utf-8"), as_user=True)
            time.sleep(READ_WEBSOCKET_DELAY)
    else:
        print("Connection failed. Invalid Slack token or bot ID?")
安迪兒 never finished the run; it was stopped after just a few days,
so with the barely trained model, the bot spat out a pile of unknown-vocabulary tokens (__UNK__).
Ah, life~安迪兒 needs a better graphics card~XD
Attached are 安迪兒's code and the half-trained files;
anyone who wants to keep training can pick up from there, or just run it as-is to play around:
> https://github.com/bowwowxx/tensorbot.git
Let's look at the result.
Log into Slack and chat with the bot for a bit.
The conversation is still kind of weird, but it's pretty entertaining.
It seems more likely to answer odd questions, about mistresses and the like;
normal questions it mostly doesn't know, or the words aren't in the vocabulary list.
That probably has something to do with the training data being all gossipy movie dialogue~XD
Hilarious...