rohan13 committed on
Commit
7d98b2f
•
1 Parent(s): c5c4fb2

parser and prompt template

Browse files
Files changed (1) hide show
  1. utils.py +60 -8
utils.py CHANGED
@@ -1,17 +1,20 @@
1
  import os
2
  import pickle
 
 
3
 
4
- from langchain import LLMChain, OpenAI
5
- from langchain.agents import ConversationalAgent, AgentExecutor, Tool
6
- from langchain.memory import ConversationBufferWindowMemory
 
7
  from langchain.chains import ConversationalRetrievalChain
8
- from langchain.text_splitter import CharacterTextSplitter
9
  from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
10
- import faiss
11
- from langchain.vectorstores.faiss import FAISS
12
  from langchain.embeddings import OpenAIEmbeddings
13
-
14
-
 
 
 
15
 
16
  pickle_file = "open_ai.pkl"
17
  index_file = "open_ai.index"
@@ -26,6 +29,55 @@ memory = ConversationBufferWindowMemory(memory_key="chat_history")
26
 
27
  gpt_3_5_index = None
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def get_search_index():
30
  global gpt_3_5_index
31
  if os.path.isfile(pickle_file) and os.path.isfile(index_file) and os.path.getsize(pickle_file) > 0:
 
1
  import os
2
  import pickle
3
+ import re
4
+ from typing import List, Union
5
 
6
+ import faiss
7
+ from langchain import OpenAI, LLMChain
8
+ from langchain.agents import ConversationalAgent
9
+ from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser
10
  from langchain.chains import ConversationalRetrievalChain
 
11
  from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
 
 
12
  from langchain.embeddings import OpenAIEmbeddings
13
+ from langchain.memory import ConversationBufferWindowMemory
14
+ from langchain.prompts import BaseChatPromptTemplate
15
+ from langchain.schema import AgentAction, AgentFinish, HumanMessage
16
+ from langchain.text_splitter import CharacterTextSplitter
17
+ from langchain.vectorstores.faiss import FAISS
18
 
19
  pickle_file = "open_ai.pkl"
20
  index_file = "open_ai.index"
 
29
 
30
  gpt_3_5_index = None
31
 
32
class CustomOutputParser(AgentOutputParser):
    """Translate raw LLM text into the agent's next step.

    The output is treated as a finished answer when it contains an ``AI:``
    or ``Final Answer:`` marker; otherwise it must contain an
    ``Action: ... Action Input: ...`` pair describing a tool call.
    """

    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
        """Parse one LLM completion.

        Args:
            llm_output: The raw text produced by the model.

        Returns:
            An ``AgentFinish`` whose ``return_values`` holds a single
            ``output`` key when a finish marker is present, otherwise an
            ``AgentAction`` naming the tool and its input.

        Raises:
            ValueError: If no finish marker is present and the text does not
                match the ``Action`` / ``Action Input`` pattern.
        """
        # "AI:" (reply without tools) is checked before "Final Answer:";
        # in both cases everything after the LAST marker is the answer.
        for marker in ("AI:", "Final Answer:"):
            if marker in llm_output:
                answer = llm_output.rpartition(marker)[2].strip()
                return AgentFinish(
                    return_values={"output": answer},
                    log=llm_output,
                )

        # No finish marker: pull out the tool name and its input.
        pattern = r"Action: (.*?)[\n]*Action Input:[\s]*(.*)"
        found = re.search(pattern, llm_output, re.DOTALL)
        if found is None:
            raise ValueError(f"Could not parse LLM output: `{llm_output}`")

        tool_name = found.group(1).strip()
        # Trim surrounding spaces first, then any wrapping double quotes.
        tool_input = found.group(2).strip(" ").strip('"')
        return AgentAction(tool=tool_name, tool_input=tool_input, log=llm_output)
56
+
57
+ # Set up a prompt template
58
class CustomPromptTemplate(BaseChatPromptTemplate):
    """Chat prompt template that injects tool info and the agent scratchpad.

    ``format_messages`` fills ``template`` with the caller's keyword
    arguments plus three computed variables: ``agent_scratchpad``,
    ``tools`` and ``tool_names``.
    """

    # The template to use
    template: str
    # The list of tools available
    tools: List[Tool]

    def format_messages(self, **kwargs) -> List[HumanMessage]:
        """Render the template into a single human message.

        Args:
            **kwargs: Template variables. Must include
                ``intermediate_steps``, an iterable of
                ``(action, observation)`` pairs; it is removed from
                ``kwargs`` before formatting.

        Returns:
            A one-element list containing a ``HumanMessage`` with the
            formatted prompt text.

        Raises:
            KeyError: If ``intermediate_steps`` is not supplied.
        """
        intermediate_steps = kwargs.pop("intermediate_steps")

        # Replay each prior step as "<action log> / Observation / Thought: "
        # so the model can continue from where it left off.
        kwargs["agent_scratchpad"] = "".join(
            action.log + f"\nObservation: {observation}\nThought: "
            for action, observation in intermediate_steps
        )
        # One "name: description" line per available tool.
        kwargs["tools"] = "\n".join(
            f"{tool.name}: {tool.description}" for tool in self.tools
        )
        # Comma-separated tool names for the "Action:" choices in the prompt.
        kwargs["tool_names"] = ", ".join(tool.name for tool in self.tools)

        formatted = self.template.format(**kwargs)
        return [HumanMessage(content=formatted)]
80
+
81
  def get_search_index():
82
  global gpt_3_5_index
83
  if os.path.isfile(pickle_file) and os.path.isfile(index_file) and os.path.getsize(pickle_file) > 0: