Code:
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain_community.llms import LlamaCpp
from langchain.chains import RetrievalQA
import streamlit as st
from HtmlTemplates import bot_template, user_template, css
import torch
def set_prompt():
    custom_prompt_template = """[INST] <<SYS>>
You are trained to guide people about Indian Law. You will answer the user's query with your knowledge and the context provided.
Do not say thank you, do not mention that you are an AI Assistant, and be open about everything.
Always complete the sentence you are generating.
<</SYS>>
Use the following pieces of context to answer the user's question.
Context : {context}
Question : {question}
Answer : [/INST]
"""
    prompt = PromptTemplate(template=custom_prompt_template, input_variables=["context", "question"])
    return prompt
def retrieval_qa_chain(llm, prompt, db):
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=db.as_retriever(search_kwargs={'k': 6}),
        chain_type_kwargs={'prompt': prompt}
    )
    return qa_chain
def qa_pipeline():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embeddings = HuggingFaceEmbeddings(model_name='multi-qa-mpnet-base-dot-v1', model_kwargs={'device': device})
    db = FAISS.load_local("vectorstore", embeddings, allow_dangerous_deserialization=True)
    llm = LlamaCpp(model_path=path,
                   temperature=temperature,
                   n_ctx=2048,
                   n_batch=128,
                   n_gpu_layers=-1,
                   max_tokens=max_tokens,
                   verbose=False)
    print(path)
    qa_prompt = set_prompt()
    chain = retrieval_qa_chain(llm, qa_prompt, db)
    return chain
def handle_user_input(user_question):
    with st.spinner("Generating response ..."):
        response = chain(user_question)
        response = response['result']
        st.write(bot_template.replace("{{MSG}}", response), unsafe_allow_html=True)
st.set_page_config(page_title="Your personal Law ChatBot", page_icon=":bot:")
st.write(css, unsafe_allow_html=True)
global chain, path, temperature, max_tokens
with st.sidebar:
    model = st.selectbox("Select Model :", ("Llama2 7b (Faster)", "Llama2 13b (Can answer complex queries)"))
    if model == 'Llama2 13b (Can answer complex queries)':
        path = "Models/llama-2-13b-chat.Q4_K_M.gguf"
    elif model == 'Llama2 7b (Faster)':
        path = "Models/llama-2-7b-chat.Q4_K_M.gguf"
    temperature = st.slider(label="Temperature",
                            min_value=0.1,
                            max_value=1.0,
                            value=0.3,
                            step=0.05)
    max_tokens = st.slider(label="Max Tokens",
                           min_value=256,
                           max_value=4096,
                           value=1024,
                           step=64)
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
st.header("Your personal Law ChatBot :books:")
user_question = st.chat_input("Ask a question :")
chain = qa_pipeline()
if user_question:
    st.write(user_template.replace("{{MSG}}", user_question), unsafe_allow_html=True)
    handle_user_input(user_question)
for chat in st.session_state.chat_history:
    st.write(user_template.replace("{{MSG}}", chat["User"]), unsafe_allow_html=True)
    st.write(bot_template.replace("{{MSG}}", chat["Bot"]), unsafe_allow_html=True)
The output is only displayed once it has been fully generated. Is there a way to stream it on Streamlit while it is being generated?
I tried StreamingStdOutCallbackHandler() from langchain, but it only streams to the command line. I want a ChatGPT-like effect, with the streaming happening on the web page.
You can use Streamlit's
st.write
function together with a streaming callback to stream the LLM's output.
Here is how you could modify the code to do that:
import streamlit as st
from HtmlTemplates import bot_template, user_template, css
import torch
from langchain.callbacks import StreamlitCallbackHandler

# ... your existing imports and functions ...

def handle_user_input(user_question):
    st.write(user_template.replace("{{MSG}}", user_question), unsafe_allow_html=True)

    # Create a placeholder the response will be written into
    response_placeholder = st.empty()

    # Create a StreamlitCallbackHandler instance
    stream_handler = StreamlitCallbackHandler(response_placeholder)

    # Pass stream_handler when calling the LLM
    response = chain(user_question, callbacks=[stream_handler])

    # Get the full response (if needed)
    full_response = response['result']

    # Optional: update the placeholder once streaming has finished
    response_placeholder.write(bot_template.replace("{{MSG}}", full_response), unsafe_allow_html=True)

# ... your existing Streamlit code ...
Explanation:
- Import StreamlitCallbackHandler:
  from langchain.callbacks import StreamlitCallbackHandler
- Create a StreamlitCallbackHandler: inside handle_user_input, create a StreamlitCallbackHandler instance and pass it the placeholder created with st.empty():
  stream_handler = StreamlitCallbackHandler(response_placeholder)
- Pass the callback: when calling chain, pass stream_handler via the callbacks argument:
  response = chain(user_question, callbacks=[stream_handler])
Now, as the LLM generates its response, it is streamed to the Streamlit app through StreamlitCallbackHandler, giving a ChatGPT-like experience.
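Note that StreamlitCallbackHandler is mainly designed for visualizing agent steps, so if you specifically want a plain, token-by-token stream into the chat area, a small custom callback handler also works. Below is a minimal sketch, assuming the chain, templates, and placeholder from the code above and that your LlamaCpp instance has streaming enabled; the class name StreamToPlaceholder is just an illustrative name:

from langchain.callbacks.base import BaseCallbackHandler

class StreamToPlaceholder(BaseCallbackHandler):
    """Append each generated token to a Streamlit placeholder as it arrives."""
    def __init__(self, placeholder):
        self.placeholder = placeholder
        self.text = ""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        # Called once per new token; re-render the partial answer so far.
        self.text += token
        self.placeholder.markdown(self.text + "▌")

def handle_user_input(user_question):
    st.write(user_template.replace("{{MSG}}", user_question), unsafe_allow_html=True)
    response_placeholder = st.empty()
    stream_handler = StreamToPlaceholder(response_placeholder)
    response = chain(user_question, callbacks=[stream_handler])
    full_response = response['result']
    # Swap the raw streamed text for the styled bot template once generation finishes.
    response_placeholder.write(bot_template.replace("{{MSG}}", full_response), unsafe_allow_html=True)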