demo采用的是streamlit:首先要能进行文字沟通,然后AI能够返回语音。正好streamlit中有streamlit.audio可以播放语音,所以剩下的就是如何将输入输出串起来,使其能够被streamlit.write展示。虚拟聊天助手的具体思路可以参考上篇文章baseline基于ChatTTS与zhipuai虚拟聊天助手(baseline)-CSDN博客https://blog.csdn.net/2301_81587902/article/details/141643805?spm=1001.2014.3001.5501
首先建立TTs.py调用ChatTTS
import torchaudio
import torch
from ChatTTS import ChatTTS
import soundfile
from IPython.display import Audio
# Module-level ChatTTS instance shared by the Tts class below;
# the (heavy) model weights are loaded once at import time.
chat = ChatTTS.Chat()
chat.load_models(compile=False)
class Tts():
    """Thin wrapper around the module-level ChatTTS `chat` instance.

    Synthesizes speech waveforms from text and optionally saves the first
    waveform to disk as a 24 kHz WAV file.
    """

    def __init__(self):
        pass

    # Speech synthesis: turn text into waveforms via ChatTTS.
    def chat_sound(self, texts, infer_code):
        # refine_text = chat.infer(texts, refine_text_only=True)
        wavs = chat.infer(texts, params_infer_code=infer_code)
        return wavs

    # Synthesize `answer` and save the first waveform.
    # `out_path` generalizes the previously hard-coded destination while
    # keeping the old value as the default, so existing callers are unaffected.
    def tts_response(self, answer, infer_code, out_path="output/output_d1.wav"):
        wavs = self.chat_sound(answer, infer_code)
        print("___"*10)
        # ChatTTS returns numpy arrays; torchaudio.save expects a tensor.
        torchaudio.save(out_path, torch.from_numpy(wavs[0]), 24000)
        return wavs
然后建立收集角色信息的模块,命名为collect_role.py。因为采用streamlit的text_input作为输入,所以可以直接输入目标角色、角色关系以及性格特点,这里不采用ai进行收集。
也可以采用ai收集,但是个人认为繁琐了点。
import streamlit as st
import os
from dataclasses import dataclass, asdict
from sqlalchemy import insert
from sqlalchemy import Table, Column, Integer, String, DateTime, Text, MetaData, SmallInteger
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import sqlite3
# Collected role information (one record per configured character).
@dataclass
class ChatSession:
    """One collected role description; mirrors the columns of chat_session_table."""
    role: str              # relationship label; later classified male/female by decide.ManWoman
    role_name: str         # display name used inside the role-play prompt
    role_personality: str  # free-text personality description
class role_collect():
    """Bundles user-supplied role attributes and persists them via SQLAlchemy."""

    def __init__(self):
        pass

    def role_prompt(self, role, role_name, role_personality):
        # Pack the three attributes into a ChatSession record, then persist it.
        session_record = ChatSession(role, role_name, role_personality)
        print("角色信息:", role_name)
        self.store(session_record)

    def store(self, all_role: ChatSession):
        # Build the INSERT first, then execute it inside a single transaction.
        stmt = insert(chat_session_table).values([asdict(all_role)])
        with SessionLocal.begin() as sess:
            sess.execute(stmt)
# Recreate the SQLite database from scratch.
# NOTE(review): removing the file here wipes any previously collected role
# data every time this module is imported — confirm this is intended.
db_file = "chatbot.db"
if os.path.exists(db_file):
    os.remove(db_file)
# SQLAlchemy engine/session factory bound to the local SQLite file.
engine = create_engine(f"sqlite:///{db_file}")
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
metadata_obj = MetaData()
# Single table holding the collected role description (no primary key defined).
chat_session_table = Table(
    "chat_session_table",
    metadata_obj,
    Column("role", String(16)),
    Column("role_name", String(32)),
    Column("role_personality", String(32)),
)
metadata_obj.create_all(engine, checkfirst=True)
print("数据库创建成功!")
之后就是读取保存好的sql数据,可以命名为use_sql.py
# 将collect收集的信息导出
import sqlite3
def query_table(table: str):
    """Return all rows from `table` in chatbot.db as a list of tuples.

    Fix: the original opened a connection and never closed it (leak on every
    call); the connection is now closed deterministically via try/finally.

    NOTE: `table` is interpolated directly into the SQL string — only pass
    trusted, program-defined table names, never user input.
    """
    con = sqlite3.connect("chatbot.db")
    try:
        cur = con.cursor()
        return cur.execute(f"SELECT * FROM {table}").fetchall()
    finally:
        con.close()
def query_chat_session():
    """Fetch every row stored in chat_session_table."""
    rows = query_table("chat_session_table")
    print("finish query chat session")
    return rows
# Bug fix: the guard compared against ' __main__' (note the leading space),
# which never matches, so the self-test below could never run.
if __name__ == '__main__':
    sql = query_chat_session()
接下来就是将收集的内容构建成prompt,并且根据角色的关系返回需要下载的语音包(目前只设计了男女两种不同的语音),所以将包命名成get_prompt.py
import streamlit as st
from TTs_down import download_speaker
from decide import ManWoman
# Module-level gender classifier, used by get_prompt() below.
manwoman = ManWoman()
# Load stored role data for the model.
def get_sql(inputs_text=""):
    """Load the stored role rows from SQLite.

    `inputs_text` may be "" (treated as "no"), "no" or "yes". Fix: in the
    original, the "no" and "yes" branches executed exactly the same query —
    the duplication is collapsed while keeping the per-branch print output.
    Returns the rows, or None after flagging an invalid answer in the UI.
    """
    choice = inputs_text if inputs_text else "no"
    if choice not in ("no", "yes"):
        st.error("输入错误,请重新输入")
        return None
    # Local import preserved from the original (avoids importing use_sql,
    # which rebuilds the database file, at module load time).
    from use_sql import query_chat_session
    query_chat = query_chat_session()
    print(choice)
    return query_chat
def get_prompt():
    """Build the role-play system prompt and fetch the matching TTS voice config.

    Returns (prompt, infer_code) where infer_code comes from download_speaker.
    """
    query_chat = get_sql()
    row = query_chat[0]
    # Take (at most) the last three columns, last column first — equivalent
    # to the original reverse-indexed loop with its `i <= 2` cutoff.
    req = [row[-(i + 1)] for i in range(min(len(row), 3))]
    role = req[2]  # relationship label (classified male/female below)
    role_name = req[1]
    role_personality = req[0]
    prompt = f'''我与你的【任何对话】都必须严格执行以下角色要求。你叫{role_name},
我与你的对话必须严格遵守以下角色要求,你将扮演我的【{role}】,我的【{role}】必须严格遵守以下角色要求:
姓名:{role_name}。性格:{role_personality}。
记住你不是AI,请你完全沉浸在{role}的角色里,我们的任何对话都要用{role_name}的口吻回答我,
回答不能超过50字,根据聊天话题延展你自己的想法。
不能有解释类型的逻辑,表达要带有角色的性格特点。
'''
    print("finish prompt")
    # Decide male/female from the role label, then load the matching voice.
    w_m = manwoman.similarity(role)
    print(w_m)
    info = download_speaker(w_m)
    return prompt, info
因此需要有判断是男性或者是女性的包,命名为decide.py
from zhipuai import ZhipuAI
import os
from dotenv import load_dotenv, find_dotenv
from sklearn.metrics.pairwise import cosine_similarity
# Read the API key from a .env file and build the ZhipuAI client once.
_ = load_dotenv(find_dotenv())
client = ZhipuAI(api_key=os.getenv("ZhipuAI_API_KEY"))
class ManWoman:
    """Classifies a role description as male ("男性") or female ("女性") by
    comparing embedding similarity against two fixed reference texts."""

    def __init__(self):
        # Fix: the original re-fetched the embeddings of the constant
        # reference texts "男性"/"女性" from the API on every similarity()
        # call. They are memoized here, halving the per-call API cost
        # without changing results.
        self._ref_embeddings = {}

    def embedding_man(self, text):
        """Return the embedding vector of `text` from ZhipuAI's embedding-2."""
        emb = client.embeddings.create(
            model="embedding-2",
            input=text,
        )
        return emb.data[0].embedding

    def _reference(self, label):
        # Lazily fetch and cache the reference embedding for `label`.
        if label not in self._ref_embeddings:
            self._ref_embeddings[label] = self.embedding_man(label)
        return self._ref_embeddings[label]

    # Compare similarity to decide whether the role is male or female.
    def similarity(self, role):
        """Return "男性" or "女性", whichever reference `role` is closer to."""
        text = self.embedding_man(role)
        man = self._reference("男性")
        woman = self._reference("女性")
        if cosine_similarity([text], [man])[0][0] > cosine_similarity([text], [woman])[0][0]:
            return "男性"
        else:
            return "女性"
if __name__ == '__main__':
    # Quick manual check of the classifier.
    detector = ManWoman()
    print(detector.similarity("汪星人"))
既然判断完男女就需要下载对应的信息命名为TTs_down.py
import torch
def download_speaker(text):
    """Load the speaker embedding matching the gender label and build the
    ChatTTS inference config.

    Fix: for an unknown label the original printed "err" and then crashed
    with a NameError on the undefined `speaker`; it now raises ValueError
    with a clear message instead.
    """
    if text == "男性":
        speaker = torch.load('speakers/b1hou.pth')
    elif text == "女性":
        speaker = torch.load('speakers/g1.pth')
    else:
        print("err")
        raise ValueError(f"unknown gender label: {text!r}")
    infer_code = {
        "spk_emb": speaker,
        # 'prompt': '[speed_10]',
        'temperature': 0.1,
        'top_P': 0.7,
        'top_K': 20,
        # "custom_voice": 3000,
    }
    print("finish download")
    return infer_code
所有的内容都构建完了就到main.py,将大部分的内容串起来,有一部分需要用demo.py才能连接
import os
import streamlit as st
from dotenv import load_dotenv, find_dotenv
from zhipuai import ZhipuAI
from get_prompt import get_prompt
from TTs import Tts
from decide import ManWoman
# Shared singletons for the app: TTS engine, gender classifier, LLM client.
tts = Tts()
manwoman = ManWoman()
# Read the API key from a .env file before constructing the client.
_ = load_dotenv(find_dotenv())
client = ZhipuAI(api_key=os.getenv("ZhipuAI_API_KEY"))
class ChatGlm():
    """Console chat loop backed by glm-4, seeded with the role-play prompt."""

    def __init__(self):
        # get_prompt() returns the role prompt plus the TTS inference config.
        self.prompt, self.info = get_prompt()
        self.msg = [{"role": "user", "content": self.prompt}]

    # NOTE: the (misspelled) public name is kept so existing callers
    # such as demo.py continue to work.
    def reponse(self, msg):
        """Return one glm-4 completion for the running message list."""
        response = client.chat.completions.create(
            model="glm-4",
            messages=msg,
            temperature=0.7,
        )
        return response.choices[0].message.content

    def check_over(self, inpt):
        """Return True when the user signals the end of the conversation."""
        if "再见" in inpt or "拜拜" in inpt or "结束" in inpt:
            return True

    def chat(self):
        # Fix: the original requested TWO completions per turn (`outp` at the
        # top of the loop and `answer` after appending the user input), then
        # stored `outp` in the history while printing `answer` — so the reply
        # shown to the user could differ from the one the model remembered.
        # One completion per turn is generated and reused for both purposes.
        answer = self.reponse(self.msg)
        while True:
            inpt = input()
            self.msg += [
                {"role": "assistant", "content": answer},
                {"role": "user", "content": inpt},
            ]
            answer = self.reponse(self.msg)
            if self.check_over(inpt):
                break
            print(answer)
最后demo.py
from main import ChatGlm
from TTs import Tts
import streamlit as st
from get_prompt import get_sql
# TTS engine instance used to voice the assistant replies in the demo UI.
tts = Tts()
# # 创建一个标题和一个副标题
st.title("
标签:ChatTTS,zhipuai,demo,st,session,role,chat,text,import
From: https://blog.csdn.net/2301_81587902/article/details/141822190