from transformers import AutoModel, AutoTokenizer
import streamlit as st

# Streamlit chat demo for ChatGLM2-6B: replays the conversation kept in
# st.session_state, lets the user tune generation parameters from the sidebar
# and streams the model's reply into the page as it is produced.
# (A comment header is used instead of a module docstring so that no bare
# string expression precedes st.set_page_config.)

st.set_page_config(
    page_title="ChatGLM2-6b 演示",
    page_icon=":robot:",
    layout='wide'
)


@st.cache_resource
def get_model():
    """Load the ChatGLM2-6B tokenizer and model.

    Both are loaded from the "THUDM/chatglm2-6b" checkpoint with
    ``trust_remote_code=True``; the model is moved to CUDA and switched to
    eval mode. The function is wrapped in ``st.cache_resource`` so the loaded
    objects are reused instead of being rebuilt on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).cuda()
    # Multi-GPU support: replace the line above with the two lines below and
    # set num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
    model = model.eval()
    return tokenizer, model


tokenizer, model = get_model()

st.title("ChatGLM2-6B")

# Generation parameters, adjustable from the sidebar and forwarded unchanged
# to model.stream_chat below.
max_length = st.sidebar.slider(
    'max_length', 0, 32768, 8192, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.8, step=0.01
)
# NOTE(review): 0.0 is selectable here; the sampling code in transformers
# presumably rejects a non-positive temperature - confirm and consider raising
# the slider minimum.
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.8, step=0.01
)

# Conversation state lives in session_state so it survives Streamlit reruns.
if 'history' not in st.session_state:
    st.session_state.history = []

if 'past_key_values' not in st.session_state:
    st.session_state.past_key_values = None

# Replay every stored (query, response) turn.
for query, response in st.session_state.history:
    with st.chat_message(name="user", avatar="user"):
        st.markdown(query)
    with st.chat_message(name="assistant", avatar="assistant"):
        st.markdown(response)

# Empty slots for the turn in progress; they are created before the input
# widgets and filled in only once the user submits a prompt.
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

prompt_text = st.text_area(label="用户命令输入",
                           height=100,
                           placeholder="请在这儿输入您的命令")

button = st.button("发送", key="predict")

# Ignore clicks with an empty or whitespace-only prompt so that no blank turn
# is sent to the model or stored in the history / KV cache.
if button and prompt_text and prompt_text.strip():
    input_placeholder.markdown(prompt_text)
    history, past_key_values = st.session_state.history, st.session_state.past_key_values
    # Each iteration yields the reply generated so far together with the
    # updated history and past_key_values; the placeholder is overwritten on
    # every step so the answer appears to stream in.
    for response, history, past_key_values in model.stream_chat(tokenizer, prompt_text, history,
                                                                past_key_values=past_key_values,
                                                                max_length=max_length, top_p=top_p,
                                                                temperature=temperature,
                                                                return_past_key_values=True):
        message_placeholder.markdown(response)

    # Persist the values from the last iteration for the next rerun.
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values