Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,585 Bytes
7a919c0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
# Copyright (c) OpenMMLab. All rights reserved.
"""Search enhancement proxy."""
import argparse
import json
import os
import re
import shlex

import pytoml
from loguru import logger

from .llm_client import ChatClient
class SourceGraphProxy:
    """A class to serve as a proxy for interacting with the Source Graph.

    Args:
        config_path (str): Path to the TOML configuration file.
        topk (int, optional): Top K results to consider from the search.
            Defaults to 1.
        language (str, optional): Language for the system prompts - 'zh' for
            Chinese and 'en' for English. Defaults to 'zh'.

    Attributes:
        config_path (str): The path of the configuration file.
        sg_config (dict): Configuration settings for sourcegraph search.
        topk (int): Top K results to consider from the search.
        language (str): Language for the system prompts.
        CHOICE_TEMPLATE (str): Template string for choosing a repository
            based on the selected language.
        KEYWORDS_TEMPLATE (str): Template string for extracting search
            keywords based on the selected language.
    """

    def __init__(self,
                 config_path: str,
                 topk: int = 1,
                 language: str = 'zh') -> None:
        """Init searcher with config."""
        self.config_path = config_path
        self.sg_config = None
        with open(self.config_path, encoding='utf8') as f:
            config = pytoml.load(f)
            self.sg_config = config['sg_search']
        self.topk = topk
        self.language = language
        if self.language == 'zh':
            self.CHOICE_TEMPLATE = '“{}”\n请仔细阅读以上问题,请问应该查询以下哪个开源项目:\n'  # noqa E501
            self.KEYWORDS_TEMPLATE = '“{}”\n请仔细阅读以上问题,提取其中可用作搜索引擎的关键字,关键字之间,分隔,不要解释。'  # noqa E501
        else:
            self.CHOICE_TEMPLATE = '"{}"\nPlease read the above question carefully, which of the following open-source projects should this question refer to: \n'  # noqa E501
            self.KEYWORDS_TEMPLATE = '"{}"\nPlease read the above questions carefully, extract the keywords which can be used as search engines, between keywords, separate, do not explain.'  # noqa E501

    def command(self, txt: str) -> str:
        """Executes a shell command and returns its output.

        Args:
            txt (str): Command to be executed in the shell.

        Returns:
            str: Output of the shell command with surrounding whitespace
                stripped.
        """
        logger.debug('cmd: {}'.format(txt))
        # `os.popen` goes through the shell; the context manager closes the
        # underlying pipe (the previous version leaked the handle).
        with os.popen(txt) as proc:
            return proc.read().strip()

    def extract_sg_result(self, jsonstr: str) -> list:
        """Extracts the desired data from the source graph result.

        Args:
            jsonstr (str): JSON string containing source graph search result.

        Returns:
            list: At most ``self.topk`` dicts, each with 'filepath' and
                'content' of a file returned by source graph. Empty on
                malformed input.
        """
        ret = []
        try:
            root = json.loads(jsonstr)
            results = root['Results']
            for result in results:
                # Only file matches carry usable path/content pairs.
                if 'FileMatch' != result['__typename']:
                    continue
                content = result['file']['content']
                path = result['file']['path']
                ret.append({'filepath': path, 'content': content})
                if len(ret) >= self.topk:
                    break
        except Exception as e:
            # Best effort: a bad payload yields an empty result, not a crash.
            logger.warning('{} when source graph parse {}'.format(
                str(e), jsonstr))
        return ret

    def choose_repo(self, llm_client, question, groupname):
        """Interactively assists user to select a repository for search based
        on user's question.

        Args:
            llm_client: Client instance for LLM.
            question (str): User's question.
            groupname (str): Name of the user's group. Currently unused;
                kept for interface compatibility.

        Returns:
            str: The github repo ID of the selected repository, or ``None``
                when the LLM picks none of the configured projects.
        """
        prompt = self.CHOICE_TEMPLATE.format(question)
        keys = self.sg_config.keys()
        # These config keys hold credentials/paths, not repo descriptions.
        skip = ['binary_src_path', 'src_access_token']
        repos = {}
        for key in keys:
            if key in skip:
                continue
            introduction = self.sg_config[key]['introduction']
            prompt += f'* {key} {introduction}\n'
            repos[key] = self.sg_config[key]
        prompt += '* none '
        choice = llm_client.generate_response(prompt=prompt,
                                              backend='remote').strip()
        # Substring match: the LLM answer need only mention the repo key.
        target_repo_id = None
        for key in repos.keys():
            if key in choice:
                target_repo_id = repos[key]['github_repo_id']
                break
        return target_repo_id

    def search(self, llm_client, question, groupname):
        """Performs a search operation in the selected repository based on the
        user's question.

        Args:
            llm_client: Client instance for LLM.
            question (str): User's question.
            groupname (str): Name of the user's group.

        Returns:
            str: Search result from source graph in JSON format; empty string
                when no repository could be chosen.
        """
        repo_id = self.choose_repo(llm_client, question, groupname)
        if repo_id is None:
            logger.warning('cannot choose repo_id')
            return ''
        ENV = 'export SRC_ACCESS_TOKEN="{}" && '.format(
            self.sg_config['src_access_token'])
        BINARY = self.sg_config['binary_src_path']
        prompt = self.KEYWORDS_TEMPLATE.format(question)
        entities = []
        entity_str = ''
        try:
            entity_str = llm_client.generate_response(prompt=prompt)
            # The LLM may answer with ASCII or fullwidth (Chinese) commas
            # depending on the prompt language; accept both and drop blanks.
            entities = [
                item.strip() for item in re.split(r'[,,]', entity_str)
                if item.strip()
            ]
        except Exception as e:
            # Best effort: failed keyword extraction means an empty search.
            logger.error('parse {} failed {}.'.format(entity_str, str(e)))
            entities = []
        search_items = []
        for entity in entities:
            # Search docs and source code for each keyword, e.g.
            # `src search -json 'repo:open-compass/opencompass summarizers'`.
            # `shlex.quote` keeps LLM-produced keywords from breaking out of
            # the shell command.
            for lang in ('MarkDown', 'Python'):
                query = 'repo:{} lang:{} {}'.format(repo_id, lang, entity)
                cmd = '{} search -json {}'.format(BINARY, shlex.quote(query))
                cmd_return = self.command(ENV + cmd)
                search_items += self.extract_sg_result(cmd_return)
        search_text = json.dumps(search_items, ensure_ascii=False, indent=2)
        return search_text
def parse_args():
    """Parses command line arguments."""
    parser = argparse.ArgumentParser(description='Source graph proxy search')
    parser.add_argument(
        '--config_path',
        default='config.ini',
        help='Source graph proxy configuration path. '
        'Default value is config.ini')
    return parser.parse_args()
if __name__ == '__main__':
    """Test search."""
    logger.add('logs/sg_search.log', rotation='4MB')
    cli_args = parse_args()
    chat_client = ChatClient(config_path=cli_args.config_path)
    proxy = SourceGraphProxy(config_path=cli_args.config_path)
    result = proxy.search(
        chat_client,
        question='请问triviaqa 5shot结果怎么在summarizer里输出呢',
        groupname='opencompass')
    print(result)
|