# EZO-AutoCoTRAG-Qwen2.5-72B-Instruct_q4 / modeling_custom_qwen.py
# AXCXEPT's picture
# Create modeling_custom_qwen.py
# 93ddbeb verified
# raw
# history blame
# No virus
# 16.8 kB
from transformers import Qwen2Config, Qwen2ForCausalLM
import torch
import requests
from bs4 import BeautifulSoup
from duckduckgo_search import DDGS
import logging
import re
# Logging configuration: set up the root logger at INFO level.
# This runs at import time, so importing this module configures logging.
logging.basicConfig(level=logging.INFO)
class CustomQwen2Config(Qwen2Config):
    """Qwen2 configuration published under the ``custom_qwen2config`` model type."""

    model_type = "custom_qwen2config"

    def __init__(self, **kwargs):
        # Keyword-only pass-through to the parent constructor.
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a config from a pretrained model name or path.

        The raw configuration dictionary is fetched with ``get_config_dict``
        and handed, together with the keyword arguments that call gives back,
        to the parent class's ``from_dict``.

        NOTE(review): presumably this replaces the stock loader so that a
        checkpoint saved under a different ``model_type`` still loads —
        confirm against the checkpoint's ``config.json``.
        """
        raw_config, remaining_kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        return super().from_dict(raw_config, **remaining_kwargs)

    def to_dict(self):
        """Serialize the config, always reporting this class's ``model_type``."""
        serialized = super().to_dict()
        serialized.update({"model_type": self.model_type})
        return serialized
class CustomQwen2Model(Qwen2ForCausalLM):
    """Qwen2 causal LM whose ``generate`` runs a search-augmented answer pipeline.

    ``generate`` decodes the incoming prompt to text, optionally retrieves web
    snippets through DuckDuckGo, asks the model whether the retrieved material
    is sufficient, and then produces the answer through a five-step prompt
    chain (``generate_answer``) or, when nothing useful was found, from the
    model's own knowledge (``self_reasoning``).

    A tokenizer must be attached with ``set_tokenizer`` before ``generate``
    is called, because every stage works on decoded text.
    """

    config_class = CustomQwen2Config

    # Placeholder "document" returned when retrieval yields nothing usable.
    _NO_RESULTS = "No relevant information found."

    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = None  # attached later via set_tokenizer()
        self.embedding_model = None  # not used by any method of this class
        self.max_iterations = 5  # maximum number of keyword-regeneration rounds
        self.use_search = True  # whether generate() consults web search
        self.top_k = 3  # number of documents kept for each search
        self.max_search_attempts = 3  # number of searches tried per round

    def set_tokenizer(self, tokenizer=None):
        """Attach the tokenizer used to encode prompts and decode outputs."""
        self.tokenizer = tokenizer

    # Parameter-setting methods
    def set_max_iterations(self, max_iterations):
        """Set the maximum number of keyword-regeneration rounds."""
        self.max_iterations = max_iterations

    def set_use_search(self, use_search):
        """Enable or disable web search inside ``generate``."""
        self.use_search = use_search

    def set_top_k(self, top_k):
        """Set how many retrieved documents are kept per search."""
        self.top_k = top_k

    def set_max_search_attempts(self, max_search_attempts):
        """Set how many searches are tried within one keyword round."""
        self.max_search_attempts = max_search_attempts

    def generate_step(self, input_ids, max_new_tokens=150):
        """Run the parent class's ``generate`` once and return its token ids.

        The ids are returned exactly as the parent produces them; for this
        decoder-only model that means the prompt followed by the continuation.
        """
        input_ids = input_ids.to(self.device)
        return super().generate(input_ids, max_new_tokens=max_new_tokens)

    def _generate_continuation(self, prompt, max_new_tokens):
        """Return only the text the model generated after ``prompt``.

        The prompt tokens echoed back by ``generate_step`` are sliced off
        before decoding, so instructions or user text inside the prompt can
        never be mistaken for model output by later parsing.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=max_new_tokens)
        new_token_ids = output_ids[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    @staticmethod
    def _parse_keywords(continuation, limit=3):
        """Turn a bullet-list continuation into a space-joined search query.

        The keyword prompts end with ``-``, so the continuation starts in the
        middle of the first bullet; the dash is restored so every bullet is
        parsed the same way. Only lines that start with ``-`` count, which
        keeps hyphenated words in trailing prose from being read as keywords.
        At most ``limit`` keywords are used because the prompts ask for three.
        """
        keywords = []
        for line in ("-" + continuation).splitlines():
            stripped = line.strip()
            if not stripped.startswith("-"):
                continue
            keyword = stripped.lstrip("-").strip()
            if keyword:
                keywords.append(keyword)
        return " ".join(keywords[:limit])

    def extract_response(self, output_ids, keyword):
        """Decode ``output_ids[0]`` and return the text after ``keyword``.

        The first occurrence of ``keyword`` in the decoded text is used. If
        the keyword does not occur, the whole decoded text is returned with
        an ``"[ALL]"`` prefix.
        """
        raw_response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        match = re.search(rf"{re.escape(keyword)}\s*(.*)", raw_response, re.DOTALL)
        if match:
            return match.group(1).strip()
        # Keyword not found: hand back everything, flagged with "[ALL]".
        return "[ALL]" + raw_response

    def generate(self, input_ids, max_new_tokens=150, **kwargs):
        """Answer the prompt in ``input_ids`` with optional web retrieval.

        Args:
            input_ids: Token ids of the prompt; only ``input_ids[0]`` is used.
            max_new_tokens: Token budget for each answer-generation step.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A tensor holding the tokenized final answer only. The prompt
            tokens are not included in the returned ids.

        Raises:
            ValueError: If no tokenizer has been attached.
        """
        if self.tokenizer is None:
            raise ValueError("Tokenizer is not set; call set_tokenizer() before generate().")

        logging.info("Maximum keyword regeneration attempts: %s", self.max_iterations)
        logging.info("External URL reference: %s", "Enabled" if self.use_search else "Disabled")
        logging.info("k_top value: %s", self.top_k)

        org_instruction = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        summarized_info = ""
        sufficient_info = False

        # With search disabled there is nothing to iterate over, so the
        # retrieval loop is skipped and the model answers from its own knowledge.
        if self.use_search:
            for keyword_attempt in range(1, self.max_iterations + 1):
                logging.info(
                    "Keyword regeneration attempt: %s/%s", keyword_attempt, self.max_iterations
                )
                logging.info("Retrieving relevant information using external URL references...")
                for search_attempt in range(1, self.max_search_attempts + 1):
                    logging.info(
                        "Search attempt: %s/%s", search_attempt, self.max_search_attempts
                    )
                    relevant_docs = self.retrieve_relevant_information(
                        org_instruction, top_k=self.top_k
                    )
                    summarized_info = self.summarize_documents(relevant_docs, org_instruction)
                    # An empty summary cannot be sufficient; skip the model call.
                    sufficient_info = bool(summarized_info) and self.is_answer_sufficient(
                        summarized_info, org_instruction
                    )
                    if sufficient_info:
                        logging.info("Sufficient information found.")
                        break
                    logging.info("Insufficient information. Attempting next search.")

                if sufficient_info or keyword_attempt == self.max_iterations:
                    # Done, or no further round will run, so new keywords would be unused.
                    break

                new_keywords = self.generate_new_keywords(org_instruction)
                if not new_keywords:
                    logging.warning("Failed to generate new keywords.")
                    break
                org_instruction = self.update_instruction_with_new_keywords(
                    org_instruction, new_keywords
                )
                logging.info("Retrying search with new keywords: %s", new_keywords)

        if sufficient_info:
            final_response = self.generate_answer(org_instruction, summarized_info, max_new_tokens)
        else:
            logging.info("Relevant data sources not found. Performing self-reasoning.")
            final_response = self.self_reasoning(org_instruction, max_new_tokens)

        return self.tokenizer.encode(final_response, return_tensors="pt").to(self.device)

    def retrieve_relevant_information(self, user_input, top_k=3):
        """Search the web for ``user_input`` and return up to ``top_k`` snippets.

        Returns a one-element list holding the "No relevant information
        found." placeholder when no query could be built, the search fails,
        or no hit carries any text.
        """
        search_query = self.generate_search_query(user_input)
        if not search_query:
            logging.warning("Search query is empty.")
            return [self._NO_RESULTS]

        try:
            with DDGS() as ddgs:
                search_results = list(
                    ddgs.text(
                        keywords=search_query,
                        region='wt-wt',
                        safesearch='off',
                        timelimit=None,
                        max_results=20,
                    )
                    or []
                )
        except Exception:
            # Network errors and rate limits are routine for a web search.
            # Degrade to "nothing found" so generate() can fall back to
            # self-reasoning instead of aborting the whole request.
            logging.exception("Web search failed for query: %s", search_query)
            return [self._NO_RESULTS]

        # Keep the text of each hit, preferring 'body' over 'snippet'.
        documents = []
        for result in search_results:
            text = result.get('body') or result.get('snippet')
            if text:
                documents.append(text)

        documents = documents[:top_k]
        return documents or [self._NO_RESULTS]

    def generate_search_query(self, user_input):
        """Ask the model for three research keywords and join them into a query.

        Returns an empty string when no keyword could be parsed.
        """
        prompt = f"""
User's question:
{user_input}
Organize what you need to know to answer this problem and list three keywords to research.
Keywords:
-"""
        search_query = self._parse_keywords(self._generate_continuation(prompt, max_new_tokens=50))
        if search_query:
            logging.info("Generated search query: %s", search_query)
        else:
            logging.warning("Failed to generate keywords.")
        return search_query

    def generate_new_keywords(self, user_input):
        """Ask the model for three replacement keywords after a failed search round.

        Returns an empty string when no keyword could be parsed.
        """
        prompt = f"""
User's question:
{user_input}
Insufficient information was obtained. Please generate new keywords.
List three new keywords.
Keywords:
-"""
        search_query = self._parse_keywords(self._generate_continuation(prompt, max_new_tokens=50))
        if search_query:
            logging.info("Regenerated search query: %s", search_query)
        else:
            logging.warning("Failed to extract regenerated keywords.")
        return search_query

    def update_instruction_with_new_keywords(self, instruction, new_keywords):
        """Return ``instruction`` with ``new_keywords`` appended after a ``Keywords:`` label."""
        return f"{instruction} Keywords: {new_keywords}"

    def is_answer_sufficient(self, summarized_info, user_input):
        """Ask the model whether ``summarized_info`` is enough to answer ``user_input``.

        Only the generated continuation is inspected. The prompt itself
        contains the word "Yes", so checking the full decoded text would
        report every summary as sufficient.
        """
        prompt = f"""
User's question:
{user_input}
Retrieved information:
{summarized_info}
Based on this information, determine if you can answer the user's question.
If yes, respond with "Yes". If no, respond with "No" only.
"""
        verdict = self._generate_continuation(prompt, max_new_tokens=10)
        return re.search(r"\byes\b", verdict, re.IGNORECASE) is not None

    def generate_answer(self, user_input, summarized_info, max_new_tokens=150):
        """Produce the final answer from retrieved information in five prompted steps.

        The steps are: extract rules and tasks, consider approaches, draft an
        initial answer, reflect on that draft, and write the final answer.
        Each step gets ``max_new_tokens`` tokens; the final answer text is returned.
        """
        # Step 1: Understanding the question and extracting key points
        step1_prompt = f"""
#User's question:
{user_input}
#Step 1: Understanding the question and extracting key points
Accurately understand the user's question or instructions.
Output the rules for answering and the tasks to be performed in a bullet list.
#Rules for answering and tasks to be performed:
"""
        step1_response = self._generate_continuation(step1_prompt, max_new_tokens)
        logging.info("Understanding the question...\n======================\n%s", step1_response)

        # Step 2: Considerations for problem-solving (prompted with step 1's output)
        step2_prompt = f"""
#Step 2: Considerations for problem-solving
Based on the content of Step 1, consider approaches and necessary information for solving the problem.
#Step 2 response:
"""
        step2_response = self._generate_continuation(step1_response + step2_prompt, max_new_tokens)
        logging.info(
            "Considering approaches for problem-solving...\n======================\n%s",
            step2_response,
        )

        # Step 3: Creating the initial answer (prompted with step 2's output)
        step3_prompt = f"""
#Step 3: Creating the initial answer
Based on the content so far, create an initial answer to the user's question.
Your information may not be up-to-date. Fully consider information from the internet.
#Latest internet information:
{summarized_info}
#Initial answer:
"""
        step3_response = self._generate_continuation(step2_response + step3_prompt, max_new_tokens)
        logging.info("Creating the initial answer...\n======================\n%s", step3_response)

        # Step 4: Reflection (Self-verification)
        reflection_prompt = f"""
#Step 4: Reflection (Self-verification)
Verify whether the initial answer accurately responds to the user's question or instructions, and point out any errors or areas for improvement.
Be cautious of overinterpreting the instructions and critically assess whether you have accurately understood them.
Your information may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question and provide an accurate answer to the question itself. (Ensure that you provide an answer to the question itself)
#User's question:
{user_input}
#Latest internet information:
{summarized_info}
#Initial answer:
{step3_response}
#Reflection result:
"""
        reflection_response = self._generate_continuation(reflection_prompt, max_new_tokens)
        logging.info("Performing reflection...\n======================\n%s", reflection_response)

        # Step 5: Creating the final answer
        final_prompt = f"""
#Step 5: Creating the final answer
Based on the reflection results, modify the initial answer as needed.
Your knowledge may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question, and check for overinterpretation, misunderstandings, omissions, and careless mistakes.
Create the final answer incorporating these.
#Initial answer:
{step3_response}
#Reflection result:
{reflection_response}
#Latest internet information:
{summarized_info}
#User's question:
{user_input}
Please provide the final answer to the user's question.
#Final answer:
"""
        return self._generate_continuation(final_prompt, max_new_tokens)

    def self_reasoning(self, user_input, max_new_tokens=150):
        """Answer ``user_input`` from the model's own knowledge, without retrieved text."""
        prompt = f"""
User's question:
{user_input}
No relevant information was found on the internet. Please use your own knowledge and reasoning to answer.
#Answer based on self-reasoning:
"""
        generated_text = self._generate_continuation(prompt, max_new_tokens)
        logging.info("Answer based on self-reasoning:\n======================\n%s", generated_text)
        return generated_text

    def process_document(self, doc, user_input):
        """Draft an answer to ``user_input`` from one retrieved document.

        Returns the drafted answer, or an empty string when the model's output
        contains the phrase "low relevance".
        """
        # Only the first 2000 characters of the document go into the prompt.
        # This note used to sit inside the f-string and leaked into the prompt text.
        prompt = f"""
User's question:
{user_input}
Content of the document:
{doc[:2000]}
Do not think of the question superficially. Use paradoxes and rephrasing to organize.
Create an answer to the question based on the content of this document.
Understand the points of disagreement between your own thoughts and the answer you would create based on this document, and prioritize the answer based on the document.
Answer:
"""
        generated_text = self._generate_continuation(prompt, max_new_tokens=500)
        logging.info("Document processing result: %s", generated_text)
        # NOTE(review): the prompt never asks the model to say "low relevance",
        # so this filter presumably fires rarely — confirm the intended wording.
        if "low relevance" in generated_text:
            return ""
        return generated_text

    def summarize_documents(self, documents, user_input):
        """Process each document and join the non-empty results with blank lines.

        Empty documents and the "No relevant information found." placeholder
        are skipped, because they carry no content to answer from.
        """
        summaries = []
        for doc in documents:
            if not doc or doc == self._NO_RESULTS:
                continue
            processed_text = self.process_document(doc, user_input)
            if processed_text:
                summaries.append(processed_text)
        return "\n\n".join(summaries)