from transformers import Qwen2Config, Qwen2ForCausalLM
from duckduckgo_search import DDGS
import logging
import re

# Configure logging
logging.basicConfig(level=logging.INFO)

class CustomQwen2Config(Qwen2Config):
    model_type = "custom_qwen2config"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        config = super().from_dict(config_dict, **kwargs)
        return config

    def to_dict(self):
        output = super().to_dict()
        output["model_type"] = self.model_type
        return output
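
# --- Illustrative registration (an assumption, not in the original file) ---
# Because this config declares a custom model_type, registering it with
# AutoConfig lets AutoConfig.from_pretrained resolve "custom_qwen2config"
# from a saved config.json. The matching model registration appears in the
# usage sketch at the bottom of this file, after CustomQwen2Model is defined.
from transformers import AutoConfig

AutoConfig.register("custom_qwen2config", CustomQwen2Config)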

class CustomQwen2Model(Qwen2ForCausalLM):
    config_class = CustomQwen2Config

    def __init__(self, config):
        super().__init__(config)
        self.tokenizer = None
        self.embedding_model = None
        self.max_iterations = 5  # Maximum number of keyword regeneration attempts
        self.use_search = True
        self.top_k = 3  # Number of documents to retrieve per search
        self.max_search_attempts = 3  # Number of search attempts per keyword set

    def set_tokenizer(self, tokenizer=None):
        self.tokenizer = tokenizer

    # Parameter setter methods
    def set_max_iterations(self, max_iterations):
        self.max_iterations = max_iterations

    def set_use_search(self, use_search):
        self.use_search = use_search

    def set_top_k(self, top_k):
        self.top_k = top_k

    def generate_step(self, input_ids, max_new_tokens=150):
        """
        Runs one generation pass on input_ids and returns the generated
        token ids (the prompt tokens are included in the output).
        """
        input_ids = input_ids.to(self.device)
        output_ids = super().generate(input_ids, max_new_tokens=max_new_tokens)
        return output_ids

    def extract_response(self, output_ids, keyword):
        """
        Extracts the tokens following a specific keyword from the generated response.
        Returns extracted text.
        """
        # Decode generated output to text
        raw_response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        
        # Extract the text that follows the keyword
        pattern = rf"{re.escape(keyword)}\s*(.*)"
        match = re.search(pattern, raw_response, re.DOTALL)

        if match:
            # Return matched parts
            extracted_text = match.group(1).strip()
            return extracted_text
        else:
            # Keyword not found: return the full decoded response, prefixed
            # with "[ALL]" so callers can detect the fallback.
            return "[ALL]" + raw_response

    def generate(self, input_ids, max_new_tokens=150, **kwargs):
        logging.info(f"Maximum keyword regeneration attempts: {self.max_iterations}")
        logging.info(f"External URL reference: {'Enabled' if self.use_search else 'Disabled'}")
        logging.info(f"k_top value: {self.top_k}")

        org_instruction = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # Number of keyword regeneration attempts made so far
        keyword_attempt = 0
        sufficient_info = False
        summarized_info = ""

        while keyword_attempt < self.max_iterations and not sufficient_info:
            logging.info(f"Keyword regeneration attempt: {keyword_attempt + 1}/{self.max_iterations}")

            # When using external references
            if self.use_search:
                logging.info("Retrieving relevant information using external URL references...")
                for search_attempt in range(1, self.max_search_attempts + 1):
                    logging.info(f"Search attempt: {search_attempt}/{self.max_search_attempts}")
                    relevant_docs = self.retrieve_relevant_information(org_instruction, top_k=self.top_k)
                    summarized_info = self.summarize_documents(relevant_docs, org_instruction)

                    # Judge whether the summarized information is enough to answer
                    sufficient_info = self.is_answer_sufficient(summarized_info, org_instruction)
                    if sufficient_info:
                        logging.info("Sufficient information found.")
                        break
                    else:
                        logging.info("Insufficient information. Attempting next search.")

                if not sufficient_info:
                    # Regenerate keywords
                    new_keywords = self.generate_new_keywords(org_instruction)
                    if new_keywords:
                        org_instruction = self.update_instruction_with_new_keywords(org_instruction, new_keywords)
                        logging.info(f"Retrying search with new keywords: {new_keywords}")
                    else:
                        logging.warning("Failed to generate new keywords.")
                        break

            else:
                # Search is disabled: skip retrieval and fall through to
                # self-reasoning immediately.
                summarized_info = ""
                sufficient_info = False
                break

            keyword_attempt += 1

        if not sufficient_info:
            logging.info("Relevant data sources not found. Performing self-reasoning.")
            final_response = self.self_reasoning(org_instruction, max_new_tokens)
        else:
            # Perform normal answer generation process
            final_response = self.generate_answer(org_instruction, summarized_info, max_new_tokens)

        # Return final answer
        final_response_ids = self.tokenizer.encode(final_response, return_tensors="pt").to(self.device)
        return final_response_ids

    def retrieve_relevant_information(self, user_input, top_k=3):
        search_query = self.generate_search_query(user_input)
        logging.info(f"Generated search query: {search_query}")

        if not search_query:
            logging.warning("Search query is empty.")
            return ["No relevant information found."]

        with DDGS() as ddgs:
            search_results = ddgs.text(
                keywords=search_query,
                region='wt-wt',
                safesearch='off',
                timelimit=None,
                max_results=20
            )
            search_results = list(search_results)

        if not search_results:
            return ["No relevant information found."]

        # Filtering search results
        documents = []
        for result in search_results:
            if 'body' in result and result['body']:
                documents.append(result['body'])
            elif 'snippet' in result and result['snippet']:
                documents.append(result['snippet'])

        # Select top k documents
        documents = documents[:top_k]
        return documents

    def generate_search_query(self, user_input):
        """
        Generates a search query using the model's inference.
        """
        # Create prompt
        prompt = f"""
User's question:
{user_input}

Organize what you need to know to answer this problem and list three keywords to research.

Keywords:
-"""
        # Encode prompt
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        # Generate output from model
        output_ids = self.generate_step(input_ids, max_new_tokens=50)
        # Extract keywords from output
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # Extract keyword section
        pattern = r"Keywords:\s*(.*)"  # Changed from "Keywords:\s*(.*)"
        match = re.search(pattern, generated_text, re.DOTALL)
        if match:
            keywords_text = match.group(1).strip()
            # Collect the bulleted keywords into a list
            keywords = re.findall(r"-\s*(.*)", keywords_text)
            search_query = ' '.join(keywords)
            logging.info(f"Generated search query: {search_query}")
            return search_query
        else:
            logging.warning("Failed to generate keywords.")
            return ""

    def generate_new_keywords(self, user_input):
        """
        Attempts to regenerate keywords.
        """
        prompt = f"""
User's question:
{user_input}

Insufficient information was obtained. Please generate new keywords.
List three new keywords.

Keywords:
-"""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=50)
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        pattern = r"Keywords:\s*(.*)"  # Changed from "Keywords:\s*(.*)"
        match = re.search(pattern, generated_text, re.DOTALL)
        if match:
            keywords_text = match.group(1).strip()
            keywords = re.findall(r"-\s*(.*)", keywords_text)
            search_query = ' '.join(keywords)
            logging.info(f"Regenerated search query: {search_query}")
            return search_query
        else:
            logging.warning("Failed to extract regenerated keywords.")
            return ""

    def update_instruction_with_new_keywords(self, instruction, new_keywords):
        """
        Incorporates new keywords into the original instruction.
        """
        # Simply appends new keywords to the original instruction.
        updated_instruction = f"{instruction} Keywords: {new_keywords}"
        return updated_instruction

    def is_answer_sufficient(self, summarized_info, user_input):
        """
        Determines if the summarized information is sufficient to answer the question.
        """
        prompt = f"""
User's question:
{user_input}

Retrieved information:
{summarized_info}

Based on this information, determine if you can answer the user's question.
If yes, respond with "Yes". If no, respond with "No" only.
"""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=10)
        # Decode only the newly generated tokens: the prompt itself contains
        # the literal word "Yes", so decoding the full sequence would match
        # unconditionally.
        new_token_ids = output_ids[0][input_ids.shape[1]:]
        generated_text = self.tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

        return "Yes" in generated_text

    def generate_answer(self, user_input, summarized_info, max_new_tokens=150):
        """
        Generates an answer based on the retrieved information.
        """
        # Step 1: Understanding the question and extracting key points
        step1_prompt = f"""
#User's question:
{user_input}

#Step 1: Understanding the question and extracting key points
Accurately understand the user's question or instructions.
Output the rules for answering and the tasks to be performed in a bullet list.

#Rules for answering and tasks to be performed:
"""
        step1_input_ids = self.tokenizer.encode(step1_prompt, return_tensors="pt").to(self.device)
        outputs_step1 = self.generate_step(step1_input_ids, max_new_tokens=max_new_tokens)
        step1_response = self.extract_response(outputs_step1, "#Rules for answering and tasks to be performed:")
        logging.info("Understanding the question...\n======================\n" + step1_response)

        # Step 2: Considerations for problem-solving
        step2_prompt = f"""
#Step 2: Considerations for problem-solving
Based on the content of Step 1, consider approaches and necessary information for solving the problem.

#Step 2 response:
"""
        step2_input_ids = self.tokenizer.encode(step1_response + step2_prompt, return_tensors="pt").to(self.device)
        outputs_step2 = self.generate_step(step2_input_ids, max_new_tokens=max_new_tokens)
        step2_response = self.extract_response(outputs_step2, "#Step 2 response:")
        logging.info("Considering approaches for problem-solving...\n======================\n" + step2_response)

        # Step 3: Creating the initial answer
        step3_prompt = f"""
#Step 3: Creating the initial answer
Based on the content so far, create an initial answer to the user's question.
Your information may not be up-to-date. Fully consider information from the internet.

#Latest internet information:
{summarized_info}

#Initial answer:
"""
        step3_input_ids = self.tokenizer.encode(step2_response + step3_prompt, return_tensors="pt").to(self.device)
        outputs_step3 = self.generate_step(step3_input_ids, max_new_tokens=max_new_tokens)
        step3_response = self.extract_response(outputs_step3, "#Initial answer:")
        logging.info("Creating the initial answer...\n======================\n" + step3_response)

        # Step 4: Reflection (Self-verification)
        reflection_prompt = f"""
#Step 4: Reflection (Self-verification)
Verify whether the initial answer accurately responds to the user's question or instructions, and point out any errors or areas for improvement.
Be cautious of overinterpreting the instructions and critically assess whether you have accurately understood them.
Your information may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question and provide an accurate answer to the question itself. (Ensure that you provide an answer to the question itself)

#User's question:
{user_input}

#Latest internet information:
{summarized_info}

#Initial answer:
{step3_response}

#Reflection result:
"""
        reflection_input_ids = self.tokenizer.encode(reflection_prompt, return_tensors="pt").to(self.device)
        outputs_reflection = self.generate_step(reflection_input_ids, max_new_tokens=max_new_tokens)
        reflection_response = self.extract_response(outputs_reflection, "#Reflection result:")
        logging.info("Performing reflection...\n======================\n" + reflection_response)

        # Step 5: Creating the final answer
        final_prompt = f"""
#Step 5: Creating the final answer
Based on the reflection results, modify the initial answer as needed.
Your knowledge may not be up-to-date. Fully consider information from the internet.
Reconfirm the user's question, and check for overinterpretation, misunderstandings, omissions, and careless mistakes.
Create the final answer incorporating these.

#Initial answer:
{step3_response}

#Reflection result:
{reflection_response}

#Latest internet information:
{summarized_info}

#User's question:
{user_input}

Please provide the final answer to the user's question.
#Final answer:
"""
        final_input_ids = self.tokenizer.encode(final_prompt, return_tensors="pt").to(self.device)
        outputs_final = self.generate_step(final_input_ids, max_new_tokens=max_new_tokens)
        final_response = self.extract_response(outputs_final, "#Final answer:").strip()

        return final_response

    def self_reasoning(self, user_input, max_new_tokens=150):
        """
        Generates an answer based on self-reasoning.
        """
        prompt = f"""
User's question:
{user_input}

No relevant information was found on the internet. Please use your own knowledge and reasoning to answer.

#Answer based on self-reasoning:
"""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=max_new_tokens)
        generated_text = self.extract_response(output_ids, "#Answer based on self-reasoning:").strip()
        logging.info("Answer based on self-reasoning:\n======================\n" + generated_text)
        return generated_text

    def process_document(self, doc, user_input):
        """
        Determines if each document is relevant to the user's question and generates an answer if applicable.
        """
        # Truncate long documents before building the prompt so the length
        # note stays out of the prompt text itself.
        doc_excerpt = doc[:2000]
        prompt = f"""
User's question:
{user_input}

Content of the document:
{doc_excerpt}

Do not think about the question superficially. Use paradoxes and rephrasing to organize your thoughts.
Create an answer to the question based on the content of this document.
Identify the points of disagreement between your own thoughts and the answer derived from this document, and prioritize the answer based on the document.
If the document is irrelevant to the question, respond with "low relevance" only.

Answer:
"""
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        output_ids = self.generate_step(input_ids, max_new_tokens=500)
        generated_text = self.extract_response(output_ids, "Answer:")
        logging.info("Document processing result: " + generated_text)
        # Discard documents the model judged irrelevant
        if "low relevance" in generated_text.lower():
            return ""
        else:
            return generated_text.strip()

    def summarize_documents(self, documents, user_input):
        """
        Processes each document and summarizes relevant information.
        """
        summaries = []
        for doc in documents:
            processed_text = self.process_document(doc, user_input)
            if processed_text:
                summaries.append(processed_text)
        return "\n\n".join(summaries)