Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
import re | |
from pyspark.sql import SparkSession, Window | |
import pyspark.sql.functions as F | |
from llama_cpp import Llama | |
from loguru import logger # Import the logger from loguru | |
# Function to read the text file and create Spark DataFrame | |
def create_spark_dataframe(text):
    """Split a book's text into chapters and load them into a Spark DataFrame.

    The text is cut at every ``CHAPTER ...`` heading (the pattern consumes the
    rest of the heading line). Fragments of 100 characters or fewer are
    discarded, and the remaining ones are numbered consecutively from 1.

    Args:
        text: Full text of the book as a single string.

    Returns:
        A Spark DataFrame with a ``text`` column holding each kept fragment
        and a ``chapter`` column holding its 1-based position.
    """
    # Cut at chapter headings, then keep only fragments longer than 100 chars.
    fragments = re.split('CHAPTER .+', text)
    chapter_list = [fragment for fragment in fragments if len(fragment) > 100]

    # Obtain (or reuse) the Spark session and limit log output to warnings.
    builder = SparkSession.builder.appName("Counting word occurrences from a book, under a microscope.")
    builder = builder.config("spark.driver.memory", "4g")
    spark = builder.getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    # Stage the rows in pandas first, then hand them over to Spark.
    chapters_pdf = pd.DataFrame({
        'text': chapter_list,
        'chapter': range(1, len(chapter_list) + 1),
    })
    return spark.createDataFrame(chapters_pdf)
# Read the "War and Peace" text file and create Spark DataFrame.
# The encoding is pinned so that decoding does not depend on the host's
# locale default.  NOTE(review): assumes the file is UTF-8 encoded -- confirm.
with open('war_and_peace.txt', 'r', encoding='utf-8') as file:
    text = file.read()
df_chapters = create_spark_dataframe(text)
# Define the Llama models.
# Both model files are loaded once, when this module is executed, and the
# resulting objects are shared by every summarization request.  The two
# constructors differ only in the model file they point to.
MODEL_Q8_0 = Llama(
    model_path="llama-2-7b-chat.ggmlv3.q8_0.bin",
    n_ctx=8192,
    n_batch=512,
)
MODEL_Q2_K = Llama(
    model_path="llama-2-7b-chat.ggmlv3.q2_K.bin",
    n_ctx=8192,
    n_batch=512,
)
# Function to summarize a chapter using the selected model | |
def llama2_summarize(chapter_text, model_version):
    """Summarize one novel chapter in a single sentence with a Llama 2 model.

    Args:
        chapter_text: The chapter to summarize. Either the text itself as a
            ``str``, or an uploaded-file object that exposes the path of the
            file as ``.name`` (presumably what the Gradio ``File`` input
            hands over -- verify against the installed Gradio version); in
            the latter case the file is read as UTF-8 text.
        model_version: ``"q8_0"`` or ``"q2_K"``, selecting ``MODEL_Q8_0`` or
            ``MODEL_Q2_K``. Surrounding whitespace is ignored.

    Returns:
        The text generated by the model, or a string starting with
        ``"Error:"`` when the model version is unknown or no chapter text
        was supplied.
    """
    # The version is typed into a free-text box, so tolerate stray whitespace.
    version = model_version.strip() if isinstance(model_version, str) else model_version

    # Choose the model based on the model_version parameter
    if version == "q8_0":
        llm = MODEL_Q8_0
    elif version == "q2_K":
        llm = MODEL_Q2_K
    else:
        return "Error: Invalid model_version."

    # An upload arrives as a file object rather than a string; concatenating
    # it into the prompt would raise TypeError, so read its contents first.
    if chapter_text is not None and not isinstance(chapter_text, str):
        upload_path = getattr(chapter_text, "name", None)
        if upload_path is None:
            return "Error: Unsupported chapter_text input."
        with open(upload_path, "r", encoding="utf-8") as upload:
            chapter_text = upload.read()

    if chapter_text is None or not chapter_text.strip():
        return "Error: No chapter text provided."

    # Template for this model version
    template = (
        "\n"
        "[INST] <<SYS>>\n"
        "You are a helpful, respectful and honest assistant.\n"
        "Always answer as helpfully as possible, while being safe.\n"
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.\n"
        "Please ensure that your responses are socially unbiased and positive in nature.\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.\n"
        "If you don't know the answer to a question, please don't share false information.\n"
        "<</SYS>>\n"
        "{INSERT_PROMPT_HERE} [/INST]\n"
    )

    # Create prompt.  The whole placeholder, braces included, is replaced so
    # that no literal "{" / "}" is left wrapped around the instruction.
    prompt = 'Summarize the following novel chapter in a single sentence (less than 100 words): ' + chapter_text
    prompt = template.replace('{INSERT_PROMPT_HERE}', prompt)

    # Log the input chapter text and model_version
    logger.info(f"Input chapter text: {chapter_text}")
    logger.info(f"Selected model version: {version}")

    # Generate summary using the selected model
    output = llm(prompt, max_tokens=-1, echo=False, temperature=0.2, top_p=0.1)
    summary = output['choices'][0]['text']

    # Log the generated summary
    logger.info(f"Generated summary: {summary}")
    return summary
# Define the Gradio interface.
# The two inputs are passed positionally to llama2_summarize as
# (chapter_text, model_version); the single text output shows its return value.
# NOTE(review): `gr.inputs.File` and `capture_session` look like older Gradio
# API -- confirm they are still accepted by the installed Gradio version.
iface = gr.Interface(
    title="Llama2 Chapter Summarizer",
    description="Upload the text file or enter the chapter text and model version ('q8_0' or 'q2_K') to get a summarized sentence.",
    fn=llama2_summarize,
    inputs=[
        gr.inputs.File(label="Upload Text File"),  # chapter_text
        "text",  # model_version
    ],
    outputs="text",  # Summary text
    live=False,
    capture_session=True,
)

if __name__ == "__main__":
    iface.launch()