import os
import json
import subprocess
from threading import Thread
import requests
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from huggingface_hub import HfApi
from datetime import datetime
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
MODEL_ID = os.environ.get("MODEL_ID")
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")
DISCORD_WEBHOOK = os.environ.get("DISCORD_WEBHOOK")
TOKEN = os.environ.get("TOKEN")
def send_discord(i,o):
url = DISCORD_WEBHOOK
embed1 = {
"description": i,
"title": "Input"
}
embed2 = {
"description": o,
"title": "Output"
}
data = {
"content": "https://huggingface.co/spaces/speakleash/Bielik-7B-Instruct-v0.1",
"username": "Bielik Logger",
"embeds": [
embed1, embed2
],
}
headers = {
"Content-Type": "application/json"
}
result = requests.post(url, json=data, headers=headers)
if 200 <= result.status_code < 300:
print(f"Webhook sent {result.status_code}")
else:
print(f"Not sent with {result.status_code}, response:\n{result.json()}")
@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
print('LLL', message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p)
# Format history with a given chat template
if CHAT_TEMPLATE == "ChatML":
stop_tokens = ["<|endoftext|>", "<|im_end|>"]
instruction = '<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n'
for human, assistant in history:
instruction += '<|im_start|>user\n' + human + '\n<|im_end|>\n<|im_start|>assistant\n' + assistant
instruction += '\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n'
elif CHAT_TEMPLATE == "Mistral Instruct":
stop_tokens = ["", "[INST]", "[INST] ", "", "[/INST]", "[/INST] "]
instruction = '[INST] ' + system_prompt
for human, assistant in history:
instruction += human + ' [/INST] ' + assistant + '[INST]'
instruction += ' ' + message + ' [/INST]'
elif CHAT_TEMPLATE == "Bielik":
stop_tokens = [""]
prompt_builder = ["[INST] "]
if system_prompt:
prompt_builder.append(f"<>\n{system_prompt}\n<>\n\n")
for human, assistant in history:
prompt_builder.append(f"{human} [/INST] {assistant}[INST] ")
prompt_builder.append(f"{message} [/INST]")
instruction = ''.join(prompt_builder)
else:
raise Exception("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")
print(instruction)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
input_ids, attention_mask = enc.input_ids, enc.attention_mask
if input_ids.shape[1] > CONTEXT_LENGTH:
input_ids = input_ids[:, -CONTEXT_LENGTH:]
generate_kwargs = dict(
{"input_ids": input_ids.to(device), "attention_mask": attention_mask.to(device)},
streamer=streamer,
do_sample=True if temperature else False,
temperature=temperature,
max_new_tokens=max_new_tokens,
top_k=top_k,
repetition_penalty=repetition_penalty,
top_p=top_p
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for new_token in streamer:
outputs.append(new_token)
if new_token in stop_tokens:
break
yield "".join(outputs)
send_discord(instruction, "".join(outputs))
api = HfApi()
day=datetime.now().strftime("%Y-%m-%d")
timestamp=datetime.now().timestamp()
dd={
'message': message,
'history': history,
'system_prompt':system_prompt,
'temperature':temperature,
'max_new_tokens':max_new_tokens,
'top_k':top_k,
'repetition_penalty':repetition_penalty,
'top_p':top_p,
'instruction':instruction,
'output':"".join(outputs)
}
api.upload_file(
path_or_fileobj=json.dumps(dd, indent=2, ensure_ascii=False).encode('utf-8'),
path_in_repo=f"{day}/{timestamp}.json",
repo_id="speakleash/bielik-logs",
repo_type="dataset",
commit_message=f"X",
token=TOKEN,
run_as_future=True
)
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
device_map="auto",
quantization_config=quantization_config,
attn_implementation="flash_attention_2",
)
on_load="""
async()=>{
alert("Przed skorzystaniem z usługi użytkownicy muszą wyrazić zgodę na następujące warunki:\\n\\nProszę pamiętać, że przedstawiony tutaj model jest narzędziem eksperymentalnym, które wciąż jest rozwijane i doskonalone.\\n\\nW trakcie procesu tworzenia modelu podjęto środki mające na celu zminimalizowanie ryzyka generowania treści wulgarnych, niedozwolonych lub nieodpowiednich. Niemniej jednak, w rzadkich przypadkach, niepożądane treści mogą zostać wygenerowane. Jeśli napotkają Państwo na jakiekolwiek treści uznane za nieodpowiednie lub naruszające zasady, prosimy o kontakt w celu zgłoszenia tego faktu. Dzięki Państwa informacjom będziemy mogli podejmować dalsze działania mające na celu poprawę i rozwój modelu, tak aby był on bezpieczny i przyjazny dla użytkowników.\\n\\nNie wolno używać modelu do celów nielegalnych, szkodliwych, brutalnych, rasistowskich lub seksualnych. Proszę nie przesyłać żadnych prywatnych informacji. Serwis gromadzi dane dialogowe użytkownika i zastrzega sobie prawo do ich rozpowszechniania na podstawie licencji Creative Commons Uznanie autorstwa (CC-BY) lub podobnej.");
}
"""
# Create Gradio interface
gr.ChatInterface(
predict,
title=EMOJI + " " + MODEL_NAME,
description=DESCRIPTION,
examples=[
["Kim jesteś?"],
["Ile to jest 9+2-1?"],
["Napisz mi coś miłego."]
],
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
additional_inputs=[
gr.Textbox("", label="System prompt"),
gr.Slider(0, 1, 0.6, label="Temperature"),
gr.Slider(128, 4096, 1024, label="Max new tokens"),
gr.Slider(1, 80, 40, label="Top K sampling"),
gr.Slider(0, 2, 1.1, label="Repetition penalty"),
gr.Slider(0, 1, 0.95, label="Top P sampling"),
],
theme=gr.themes.Soft(primary_hue=COLOR),
js=on_load,
).queue().launch()