File size: 4,696 Bytes
2846f20 3ae0fe0 2846f20 38e18ec 2846f20 0546764 2846f20 0546764 2846f20 0546764 2846f20 0546764 2846f20 3ae0fe0 2846f20 0546764 2846f20 38e18ec 2846f20 0546764 2846f20 5756334 3ae0fe0 38e18ec 3ae0fe0 0546764 2846f20 0546764 2846f20 0546764 2846f20 0546764 3ae0fe0 2846f20 99296d0 3ae0fe0 2846f20 3ae0fe0 2846f20 0546764 2846f20 38e18ec 2846f20 38e18ec 0546764 2846f20 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import streamlit as st
import replicate
import os
from transformers import AutoTokenizer
# # Assuming you have a specific tokenizer for Llama; if not, use an appropriate one like this:
# tokenizer = AutoTokenizer.from_pretrained("allenai/llama")
# text = "Example text to tokenize."
# tokens = tokenizer.tokenize(text)
# num_tokens = len(tokens)
# print("Number of tokens:", num_tokens)
# Browser-tab title; kept ahead of every other st.* call in this script.
st.set_page_config(page_title="Snowflake Arctic")

# Chat avatars keyed by message role: the Snowflake logo for the model,
# a skier emoji for the human.
icons = {
    "assistant": "./Snowflake_Logomark_blue.svg",
    "user": "⛷️",
}

# Sidebar: resolve the Replicate API token and expose the sampling controls.
with st.sidebar:
    st.title('Snowflake Arctic')
    if 'REPLICATE_API_TOKEN' not in st.secrets:
        # No token in the app's secrets, so ask the user to paste one.
        replicate_api = st.text_input('Enter Replicate API token:', type='password')
        # Heuristic format check (prefix 'r8_', exactly 40 characters). It only
        # decides whether to show the hint below; the value is used regardless.
        if not replicate_api.startswith('r8_') or len(replicate_api) != 40:
            st.warning('Please enter your Replicate API token.', icon='⚠️')
            st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")
    else:
        replicate_api = st.secrets['REPLICATE_API_TOKEN']

    # replicate.stream() is later called without an explicit token, so the
    # client presumably reads it from this environment variable.
    os.environ['REPLICATE_API_TOKEN'] = replicate_api

    st.subheader("Adjust model parameters")
    temperature = st.sidebar.slider(
        'temperature', min_value=0.01, max_value=5.0, value=0.3, step=0.01
    )
    top_p = st.sidebar.slider(
        'top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01
    )
# Seed the conversation with the assistant's greeting only when the session
# has no message history yet (membership test directly on session_state).
if "messages" not in st.session_state:
    st.session_state.messages = [{"role": "assistant", "content": "Hi. I'm Arctic, a new, efficient, intelligent, and truly open language model created by Snowflake AI Research. Ask me anything."}]

# Re-render every stored message, using the avatar registered for its role.
for message in st.session_state.messages:
    with st.chat_message(message["role"], avatar=icons[message["role"]]):
        st.write(message["content"])
def clear_chat_history():
    """Reset the conversation so it contains only the assistant's greeting.

    Used as the ``on_click`` callback of the "Clear chat history" buttons.
    """
    greeting = {
        "role": "assistant",
        "content": "Hi. I'm Arctic, a new, efficient, intelligent, and truly open language model created by Snowflake AI Research. Ask me anything.",
    }
    st.session_state.messages = [greeting]
# Sidebar footer: a history-reset button and an attribution caption.
with st.sidebar:
    st.button('Clear chat history', on_click=clear_chat_history)
    st.caption('Built by [Snowflake](https://snowflake.com/) to demonstrate [Snowflake Arctic](https://www.snowflake.com/blog/arctic-open-and-efficient-foundation-language-models-snowflake).')
@st.cache_resource(show_spinner=False)
def get_tokenizer():
    """Return the tokenizer used to estimate how long a prompt is.

    The prompt-length guard uses it so that the app does not send too much
    text to the model. The Llama-7B tokenizer (``huggyllama/llama-7b``) is a
    stand-in; the intent is to replace it with an ArcticTokenizer eventually.
    ``st.cache_resource`` lets Streamlit reuse the loaded object rather than
    reloading it, and ``show_spinner=False`` suppresses the loading spinner.
    """
    tokenizer_repo = "huggyllama/llama-7b"
    return AutoTokenizer.from_pretrained(tokenizer_repo)
def get_num_tokens(prompt):
    """Count the tokens that ``prompt`` splits into.

    Args:
        prompt: Text to measure.

    Returns:
        int: Length of the token list produced by ``get_tokenizer().tokenize``.
    """
    return len(get_tokenizer().tokenize(prompt))
def _format_chatml_prompt(messages):
    """Render a chat history as a ChatML-style prompt ending in an open assistant turn.

    Each message becomes ``<|im_start|>ROLE\\nCONTENT<|im_end|>``. ROLE is
    "user" for user messages and "assistant" for every other role. Turns are
    joined by newlines, and the result ends with ``<|im_start|>assistant\\n``
    for the model to complete.
    """
    turns = []
    for msg in messages:
        role = "user" if msg["role"] == "user" else "assistant"
        # Plain concatenation (not an f-string) so non-str content still fails loudly.
        turns.append("<|im_start|>" + role + "\n" + msg["content"] + "<|im_end|>")
    turns.append("<|im_start|>assistant")
    turns.append("")  # yields the trailing newline after the open assistant tag
    return "\n".join(turns)


def generate_arctic_response(max_prompt_tokens=3072):
    """Stream Snowflake Arctic's reply to the conversation in session state.

    Formats ``st.session_state.messages`` into a single prompt. It then
    rejects the prompt if it is too long, and otherwise streams the
    completion from Replicate's ``snowflake/snowflake-arctic-instruct`` model.
    The sampling values come from the module-level ``temperature`` and
    ``top_p`` settings.

    This is a generator: none of the work, including the length check, runs
    until the caller starts iterating.

    Args:
        max_prompt_tokens: Prompts whose token count (per ``get_num_tokens``)
            is at least this value are rejected. Defaults to 3072.

    Yields:
        str: Successive chunks of the model's output.

    When the prompt is too long, this shows an error and a "Clear chat history"
    button, then halts the script run with ``st.stop()``.
    """
    prompt_str = _format_chatml_prompt(st.session_state.messages)

    if get_num_tokens(prompt_str) >= max_prompt_tokens:
        st.error(f"Conversation length too long. Please keep it under {max_prompt_tokens} tokens.")
        # Explicit key: the sidebar already has a keyless button with this same label.
        st.button('Clear chat history', on_click=clear_chat_history, key="clear_chat_history")
        st.stop()

    for event in replicate.stream(
        "snowflake/snowflake-arctic-instruct",
        input={
            "prompt": prompt_str,
            # Identity template: prompt_str is already fully formatted above.
            "prompt_template": r"{prompt}",
            "temperature": temperature,
            "top_p": top_p,
        },
    ):
        yield str(event)
# Record and echo the user's new chat input, if any. The input box is
# disabled while replicate_api is empty.
if prompt := st.chat_input(disabled=not replicate_api):
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Same avatar mapping as the history renderer, instead of a duplicated literal.
    with st.chat_message("user", avatar=icons["user"]):
        st.write(prompt)

# If the latest message is not from the assistant, stream Arctic's reply and
# store the full text in the history.
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant", avatar=icons["assistant"]):
        response = generate_arctic_response()
        full_response = st.write_stream(response)
    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)