Upload 4 files
Browse files
README.md
CHANGED
@@ -1,13 +1,15 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
|
|
|
|
|
11 |
---
|
12 |
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: RAG-Chatbot
|
3 |
+
emoji: ๐w๐
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.39.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
short_description: A retrieval system with chatbot integration
|
11 |
+
thumbnail: >-
|
12 |
+
https://cdn-uploads.huggingface.co/production/uploads/6527e89a8808d80ccff88b7a/XVgtQiizeFHIUUj1huwdv.png
|
13 |
---
|
14 |
|
15 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from datasets import load_dataset
|
3 |
+
import os
|
4 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
|
5 |
+
import torch
|
6 |
+
from threading import Thread
|
7 |
+
from sentence_transformers import SentenceTransformer
|
8 |
+
import faiss
|
9 |
+
import fitz # PyMuPDF
|
10 |
+
|
11 |
+
# ํ๊ฒฝ ๋ณ์์์ Hugging Face ํ ํฐ ๊ฐ์ ธ์ค๊ธฐ
|
12 |
+
token = os.environ.get("HF_TOKEN")
|
13 |
+
|
14 |
+
|
15 |
+
# ์๋ฒ ๋ฉ ๋ชจ๋ธ ๋ก๋
|
16 |
+
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
|
17 |
+
|
18 |
+
# PDF์์ ํ
์คํธ ์ถ์ถ
|
19 |
+
def extract_text_from_pdf(pdf_path):
|
20 |
+
doc = fitz.open(pdf_path)
|
21 |
+
text = ""
|
22 |
+
for page in doc:
|
23 |
+
text += page.get_text()
|
24 |
+
return text
|
25 |
+
|
26 |
+
# ๋ฒ๋ฅ ๋ฌธ์ PDF ๊ฒฝ๋ก ์ง์ ๋ฐ ํ
์คํธ ์ถ์ถ
|
27 |
+
pdf_path = "laws.pdf" # ์ฌ๊ธฐ์ ์ค์ PDF ๊ฒฝ๋ก๋ฅผ ์
๋ ฅํ์ธ์.
|
28 |
+
law_text = extract_text_from_pdf(pdf_path)
|
29 |
+
|
30 |
+
# ๋ฒ๋ฅ ๋ฌธ์ ํ
์คํธ๋ฅผ ๋ฌธ์ฅ ๋จ์๋ก ๋๋๊ณ ์๋ฒ ๋ฉ
|
31 |
+
law_sentences = law_text.split('\n') # Adjust splitting based on your PDF structure
|
32 |
+
law_embeddings = ST.encode(law_sentences)
|
33 |
+
|
34 |
+
# FAISS ์ธ๋ฑ์ค ์์ฑ ๋ฐ ์๋ฒ ๋ฉ ์ถ๊ฐ
|
35 |
+
index = faiss.IndexFlatL2(law_embeddings.shape[1])
|
36 |
+
index.add(law_embeddings)
|
37 |
+
|
38 |
+
# Hugging Face์์ ๋ฒ๋ฅ ์๋ด ๋ฐ์ดํฐ์
๋ก๋
|
39 |
+
dataset = load_dataset("jihye-moon/LawQA-Ko")
|
40 |
+
data = dataset["train"]
|
41 |
+
|
42 |
+
# ์ง๋ฌธ ์ปฌ๋ผ์ ์๋ฒ ๋ฉํ์ฌ ์๋ก์ด ์ปฌ๋ผ์ ์ถ๊ฐ
|
43 |
+
data = data.map(lambda x: {"question_embedding": ST.encode(x["question"])}, batched=True)
|
44 |
+
data.add_faiss_index(column="question_embedding")
|
45 |
+
|
46 |
+
# LLaMA ๋ชจ๋ธ ์ค์
|
47 |
+
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
|
48 |
+
bnb_config = BitsAndBytesConfig(
|
49 |
+
load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
|
50 |
+
)
|
51 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
|
52 |
+
model = AutoModelForCausalLM.from_pretrained(
|
53 |
+
model_id,
|
54 |
+
torch_dtype=torch.bfloat16,
|
55 |
+
device_map="auto",
|
56 |
+
quantization_config=bnb_config,
|
57 |
+
token=token
|
58 |
+
)
|
59 |
+
|
60 |
+
SYS_PROMPT = """You are an assistant for answering legal questions.
|
61 |
+
You are given the extracted parts of legal documents and a question. Provide a conversational answer.
|
62 |
+
If you don't know the answer, just say "I do not know." Don't make up an answer.
|
63 |
+
you must answer korean."""
|
64 |
+
|
65 |
+
# ๋ฒ๋ฅ ๋ฌธ์ ๊ฒ์ ํจ์
|
66 |
+
def search_law(query, k=5):
|
67 |
+
query_embedding = ST.encode([query])
|
68 |
+
D, I = index.search(query_embedding, k)
|
69 |
+
return [(law_sentences[i], D[0][idx]) for idx, i in enumerate(I[0])]
|
70 |
+
|
71 |
+
# ๋ฒ๋ฅ ์๋ด ๋ฐ์ดํฐ ๊ฒ์ ํจ์
|
72 |
+
def search_qa(query, k=3):
|
73 |
+
scores, retrieved_examples = data.get_nearest_examples(
|
74 |
+
"question_embedding", ST.encode(query), k=k
|
75 |
+
)
|
76 |
+
return [retrieved_examples["answer"][i] for i in range(k)]
|
77 |
+
|
78 |
+
# ์ต์ข
ํ๋กฌํํธ ์์ฑ
|
79 |
+
def format_prompt(prompt, law_docs, qa_docs):
|
80 |
+
PROMPT = f"Question: {prompt}\n\nLegal Context:\n"
|
81 |
+
for doc in law_docs:
|
82 |
+
PROMPT += f"{doc[0]}\n" # Assuming doc[0] contains the relevant text
|
83 |
+
PROMPT += "\nLegal QA:\n"
|
84 |
+
for doc in qa_docs:
|
85 |
+
PROMPT += f"{doc}\n"
|
86 |
+
return PROMPT
|
87 |
+
|
88 |
+
# ์ฑ๋ด ์๋ต ํจ์
|
89 |
+
def talk(prompt, history):
|
90 |
+
law_results = search_law(prompt, k=3)
|
91 |
+
qa_results = search_qa(prompt, k=3)
|
92 |
+
|
93 |
+
retrieved_law_docs = [result[0] for result in law_results]
|
94 |
+
formatted_prompt = format_prompt(prompt, retrieved_law_docs, qa_results)
|
95 |
+
formatted_prompt = formatted_prompt[:2000] # GPU ๋ฉ๋ชจ๋ฆฌ ๋ถ์กฑ์ ํผํ๊ธฐ ์ํด ํ๋กฌํํธ ์ ํ
|
96 |
+
|
97 |
+
messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
|
98 |
+
|
99 |
+
# ๋ชจ๋ธ์๊ฒ ์์ฑ ์ง์
|
100 |
+
input_ids = tokenizer.apply_chat_template(
|
101 |
+
messages,
|
102 |
+
add_generation_prompt=True,
|
103 |
+
return_tensors="pt"
|
104 |
+
).to(model.device)
|
105 |
+
|
106 |
+
streamer = TextIteratorStreamer(
|
107 |
+
tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
|
108 |
+
)
|
109 |
+
|
110 |
+
generate_kwargs = dict(
|
111 |
+
input_ids=input_ids,
|
112 |
+
streamer=streamer,
|
113 |
+
max_new_tokens=1024,
|
114 |
+
do_sample=True,
|
115 |
+
top_p=0.95,
|
116 |
+
temperature=0.75,
|
117 |
+
eos_token_id=tokenizer.eos_token_id,
|
118 |
+
)
|
119 |
+
|
120 |
+
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
121 |
+
t.start()
|
122 |
+
|
123 |
+
outputs = []
|
124 |
+
for text in streamer:
|
125 |
+
outputs.append(text)
|
126 |
+
yield "".join(outputs)
|
127 |
+
|
128 |
+
# Gradio ์ธํฐํ์ด์ค ์ค์
|
129 |
+
TITLE = "Legal RAG Chatbot"
|
130 |
+
DESCRIPTION = """A chatbot that uses Retrieval-Augmented Generation (RAG) for legal consultation.
|
131 |
+
This chatbot can search legal documents and previous legal QA pairs to provide answers."""
|
132 |
+
|
133 |
+
demo = gr.ChatInterface(
|
134 |
+
fn=talk,
|
135 |
+
chatbot=gr.Chatbot(
|
136 |
+
show_label=True,
|
137 |
+
show_share_button=True,
|
138 |
+
show_copy_button=True,
|
139 |
+
likeable=True,
|
140 |
+
layout="bubble",
|
141 |
+
bubble_full_width=False,
|
142 |
+
),
|
143 |
+
theme="Soft",
|
144 |
+
examples=[["What are the regulations on data privacy?"]],
|
145 |
+
title=TITLE,
|
146 |
+
description=DESCRIPTION,
|
147 |
+
)
|
148 |
+
|
149 |
+
# Gradio ๋ฐ๋ชจ ์คํ
|
150 |
+
demo.launch(debug=True)
|
laws.pdf
ADDED
Binary file (836 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
spaces
|
2 |
+
torch==2.2.0
|
3 |
+
transformers
|
4 |
+
sentence-transformers
|
5 |
+
faiss-gpu
|
6 |
+
datasets
|
7 |
+
accelerate
|
8 |
+
bitsandbytes
|
9 |
+
PyMuPDF
|