"""Gradio chat demo for a causal LM with Discord and HF-dataset logging.

Configuration is read from the environment: MODEL_ID, CHAT_TEMPLATE,
CONTEXT_LENGTH, COLOR, EMOJI, DESCRIPTION, DISCORD_WEBHOOK and TOKEN.
"""
import json
import os
import random
import subprocess
import sys
from datetime import datetime
from threading import Thread

# Third-party imports are kept in their original relative order; confirm the
# `spaces` package's import-order requirements before re-sorting them.
import requests
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from huggingface_hub import HfApi

# Install flash-attn at startup, skipping the local CUDA build.
# The child process must inherit the current environment (PATH, HOME, ...):
# passing only the single variable as `env` would wipe everything else.
# Best-effort, as before: no check=True, so a failed install is not raised here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

MODEL_ID = os.environ.get("MODEL_ID")
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")
DISCORD_WEBHOOK = os.environ.get("DISCORD_WEBHOOK")
TOKEN = os.environ.get("TOKEN")

# Dataset repo that receives every conversation and like/dislike record.
LOG_REPO_ID = "speakleash/bielik-logs"
# Seconds to wait for the Discord webhook before giving up.
DISCORD_TIMEOUT_S = 10
# Example prompts shown under the chat box (shuffled on each page load).
EXAMPLES = [
    ["Kim jesteś?"],
    ["Ile to jest 9+2-1?"],
    ["Napisz mi coś miłego."],
]

api = HfApi()


def send_discord(i, o):
    """Post a prompt (*i*) / response (*o*) pair to the Discord webhook.

    Best-effort: a missing webhook URL, a network error or a non-2xx reply is
    reported on stdout instead of raised, so logging never breaks a chat reply.
    """
    url = DISCORD_WEBHOOK
    if not url:
        print("DISCORD_WEBHOOK not set, webhook skipped")
        return
    data = {
        "content": "https://huggingface.co/spaces/speakleash/Bielik-7B-Instruct-v0.1",
        "username": "Bielik Logger",
        "embeds": [
            {"description": i, "title": "Input"},
            {"description": o, "title": "Output"},
        ],
    }
    headers = {"Content-Type": "application/json"}
    try:
        result = requests.post(url, json=data, headers=headers, timeout=DISCORD_TIMEOUT_S)
    except requests.RequestException as exc:
        print(f"Webhook not sent: {exc}")
        return
    if 200 <= result.status_code < 300:
        print(f"Webhook sent {result.status_code}")
    else:
        # .text rather than .json(): error bodies are not guaranteed to be JSON.
        print(f"Not sent with {result.status_code}, response:\n{result.text}")


def _upload_json_log(payload, commit_message, prefix=""):
    """Upload *payload* as pretty-printed JSON to LOG_REPO_ID, without waiting.

    The file lands at ``{prefix}{YYYY-MM-DD}/{unix_timestamp}.json``. Both path
    parts come from a single clock reading so they always agree.
    """
    now = datetime.now()
    api.upload_file(
        path_or_fileobj=json.dumps(payload, indent=2, ensure_ascii=False).encode("utf-8"),
        path_in_repo=f"{prefix}{now.strftime('%Y-%m-%d')}/{now.timestamp()}.json",
        repo_id=LOG_REPO_ID,
        repo_type="dataset",
        commit_message=commit_message,
        token=TOKEN,
        run_as_future=True,  # upload in the background; errors live in the future
    )


# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): this config is never passed to from_pretrained below, so the
# model is NOT loaded in 4-bit (the logged precision is "auto <dtype>").
# Pass quantization_config=quantization_config if 4-bit loading is intended.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    attn_implementation="flash_attention_2",
)


@spaces.GPU()
def generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream a completion for *instruction*, yielding the accumulated text so far.

    Generation runs in a background thread and feeds a TextIteratorStreamer.
    The prompt is left-truncated to the last CONTEXT_LENGTH tokens. A zero
    *temperature* switches to greedy decoding. Streaming stops early when a
    streamed chunk equals one of *stop_tokens*.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # No tokenizer-side truncation: with the default right-side truncation it
    # would cut the END of the prompt (the newest message). The most recent
    # CONTEXT_LENGTH tokens are kept below instead.
    enc = tokenizer([instruction], return_tensors="pt", padding=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    if input_ids.shape[1] > CONTEXT_LENGTH:
        # The mask must be sliced exactly like the ids, or their lengths diverge.
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=bool(temperature),
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        # The streamer emits "" while a word is still incomplete; an empty
        # chunk must never count as a stop token.
        # NOTE(review): breaking does not stop model.generate in thread `t`;
        # it keeps running until EOS or max_new_tokens.
        if new_text and new_text in stop_tokens:
            break
        yield "".join(outputs)


def _build_instruction(message, history, system_prompt):
    """Render the conversation as one prompt string for CHAT_TEMPLATE.

    *history* is iterated as (user, assistant) pairs. Returns
    ``(instruction, stop_tokens)``. Raises ValueError for an unknown template.
    """
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_start|>".replace("start", "end")] if False else ["<|endoftext|>", "<|im_end|>"]
        parts = ["<|im_start|>system\n" + system_prompt + "\n<|im_end|>\n"]
        for human, assistant in history:
            parts.append("<|im_start|>user\n" + human + "\n<|im_end|>\n")
            # Every assistant turn is closed with its own end marker.
            parts.append("<|im_start|>assistant\n" + assistant + "\n<|im_end|>\n")
        parts.append("<|im_start|>user\n" + message + "\n<|im_end|>\n<|im_start|>assistant\n")
        return "".join(parts), stop_tokens

    if CHAT_TEMPLATE == "Mistral Instruct":
        # NOTE(review): "</s>" / "<s>" replace entries that appeared as empty
        # strings (likely stripped tags); confirm against the deployed file.
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        instruction = "[INST] " + system_prompt
        for human, assistant in history:
            instruction += human + " [/INST] " + assistant + "[INST]"
        instruction += " " + message + " [/INST]"
        return instruction, stop_tokens

    if CHAT_TEMPLATE == "Bielik":
        stop_tokens = ["</s>"]
        prompt_builder = ["[INST] "]
        if system_prompt:
            # Llama-2-style system block. NOTE(review): these markers appeared
            # as bare "<>" (likely stripped tags); confirm the intended format.
            prompt_builder.append(f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n")
        for human, assistant in history:
            prompt_builder.append(f"{human} [/INST] {assistant}[INST] ")
        prompt_builder.append(f"{message} [/INST]")
        return "".join(prompt_builder), stop_tokens

    raise ValueError("Incorrect chat template, select 'ChatML', 'Mistral Instruct' or 'Bielik'")


def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Gradio ChatInterface callback: stream a reply, then log the exchange.

    Yields the growing reply text. Once streaming finishes, the final prompt
    and reply are sent to Discord and uploaded to LOG_REPO_ID.
    """
    # UI sliders may deliver numbers as floats; generation expects these types.
    repetition_penalty = float(repetition_penalty)
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    print('LLL', [message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p])

    instruction, stop_tokens = _build_instruction(message, history, system_prompt)
    print(instruction)

    # Pre-bind so logging still works when generate() yields nothing.
    output_text = ""
    for output_text in generate(instruction, stop_tokens, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
        yield output_text

    send_discord(instruction, output_text)
    record = {
        'message': message,
        'history': history,
        'system_prompt': system_prompt,
        'temperature': temperature,
        'max_new_tokens': max_new_tokens,
        'top_k': top_k,
        'repetition_penalty': repetition_penalty,
        'top_p': top_p,
        'instruction': instruction,
        'output': output_text,
        'precision': 'auto ' + str(model.dtype),
    }
    _upload_json_log(record, commit_message="X")


# JavaScript run on page load: shows the terms-of-use notice (in Polish).
on_load = """
async()=>{
    alert("Przed skorzystaniem z usługi użytkownicy muszą wyrazić zgodę na następujące warunki:\\n\\nProszę pamiętać, że przedstawiony tutaj model jest narzędziem eksperymentalnym, które wciąż jest rozwijane i doskonalone.\\n\\nW trakcie procesu tworzenia modelu podjęto środki mające na celu zminimalizowanie ryzyka generowania treści wulgarnych, niedozwolonych lub nieodpowiednich. Niemniej jednak, w rzadkich przypadkach, niepożądane treści mogą zostać wygenerowane. Jeśli napotkają Państwo na jakiekolwiek treści uznane za nieodpowiednie lub naruszające zasady, prosimy o kontakt w celu zgłoszenia tego faktu. Dzięki Państwa informacjom będziemy mogli podejmować dalsze działania mające na celu poprawę i rozwój modelu, tak aby był on bezpieczny i przyjazny dla użytkowników.\\n\\nNie wolno używać modelu do celów nielegalnych, szkodliwych, brutalnych, rasistowskich lub seksualnych. Proszę nie przesyłać żadnych prywatnych informacji. Serwis gromadzi dane dialogowe użytkownika i zastrzega sobie prawo do ich rozpowszechniania na podstawie licencji Creative Commons Uznanie autorstwa (CC-BY) lub podobnej.");
}
"""


def vote(chatbot, data: gr.LikeData):
    """Log a like/dislike on a chatbot message to LOG_REPO_ID under ``liked/``."""
    _upload_json_log(
        {"history": chatbot, 'index': data.index, 'liked': data.liked},
        commit_message="L",
        prefix="liked/",
    )


# Create Gradio interface
def update_examples():
    """Return the example prompts in a random order, as a gr.Dataset update."""
    exs = [list(example) for example in EXAMPLES]  # copy; EXAMPLES stays intact
    random.shuffle(exs)
    return gr.Dataset(samples=exs)


# The theme is also set on the root Blocks: a theme passed only to the nested
# ChatInterface is presumably ignored when it is rendered inside another Blocks.
with gr.Blocks(js=on_load, theme=gr.themes.Soft(primary_hue=COLOR)) as demo:
    chatbot = gr.Chatbot(label="Chatbot", likeable=True, render=False)
    chatbot.like(vote, [chatbot], None)
    chat = gr.ChatInterface(
        predict,
        chatbot=chatbot,
        title=EMOJI + " " + MODEL_NAME + " - online chat demo",
        description=DESCRIPTION,
        examples=[list(example) for example in EXAMPLES],
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox("", label="System prompt", render=False),
            gr.Slider(0, 1, 0.6, label="Temperature", render=False),
            gr.Slider(128, 4096, 1024, label="Max new tokens", render=False),
            gr.Slider(1, 80, 40, step=1, label="Top K sampling", render=False),
            gr.Slider(0, 2, 1.1, label="Repetition penalty", render=False),
            gr.Slider(0, 1, 0.95, label="Top P sampling", render=False),
        ],
        theme=gr.themes.Soft(primary_hue=COLOR),
    )
    demo.load(update_examples, None, chat.examples_handler.dataset)

demo.queue(max_size=20).launch()