Hansimov commited on
Commit
03712cd
1 Parent(s): bee1f65

:gem: [Feature] Enable stream output of chat completions

Browse files
Files changed (1) hide show
  1. tests/openai.py +38 -12
tests/openai.py CHANGED
@@ -1,4 +1,6 @@
1
  import copy
 
 
2
  import uuid
3
 
4
  from pathlib import Path
@@ -46,17 +48,16 @@ class OpenaiAPI:
46
  "http": http_proxy,
47
  "https": http_proxy,
48
  }
 
 
49
  else:
50
  self.requests_proxies = None
51
 
52
  def log_request(self, url, method="GET"):
53
- if ENVER["http_proxy"]:
54
- logger.note(f"> Using Proxy:", end=" ")
55
- logger.mesg(f"{ENVER['http_proxy']}")
56
  logger.note(f"> {method}:", end=" ")
57
  logger.mesg(f"{url}", end=" ")
58
 
59
- def log_response(self, res: requests.Response, stream=False):
60
  status_code = res.status_code
61
  status_code_str = f"[{status_code}]"
62
 
@@ -64,12 +65,35 @@ class OpenaiAPI:
64
  logger_func = logger.success
65
  else:
66
  logger_func = logger.warn
 
67
  logger_func(status_code_str)
68
 
69
- if stream:
70
- logger_func(res.text)
71
- else:
72
- logger_func(res.json())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  def get_models(self):
75
  self.log_request(self.api_models)
@@ -111,7 +135,7 @@ class OpenaiAPI:
111
  "metadata": {},
112
  }
113
  ],
114
- "parent_message_id": str(uuid.uuid4()),
115
  "model": "text-davinci-002-render-sha",
116
  "timezone_offset_min": -480,
117
  "suggestions": [],
@@ -124,22 +148,24 @@ class OpenaiAPI:
124
  "websocket_request_id": str(uuid.uuid4()),
125
  }
126
  self.log_request(self.api_conversation, method="POST")
127
- res = requests.post(
 
128
  self.api_conversation,
129
  headers=requests_headers,
130
  json=post_data,
131
  proxies=self.requests_proxies,
132
  timeout=10,
133
  impersonate="chrome120",
 
134
  )
135
- self.log_response(res, stream=True)
136
 
137
 
138
  if __name__ == "__main__":
139
  api = OpenaiAPI()
140
  # api.get_models()
141
  api.auth()
142
- prompt = "who are you?"
143
  api.chat_completions(prompt)
144
 
145
  # python -m tests.openai
 
1
  import copy
2
+ import json
3
+ import re
4
  import uuid
5
 
6
  from pathlib import Path
 
48
  "http": http_proxy,
49
  "https": http_proxy,
50
  }
51
+ logger.note(f"> Using Proxy:", end=" ")
52
+ logger.mesg(f"{ENVER['http_proxy']}")
53
  else:
54
  self.requests_proxies = None
55
 
56
  def log_request(self, url, method="GET"):
 
 
 
57
  logger.note(f"> {method}:", end=" ")
58
  logger.mesg(f"{url}", end=" ")
59
 
60
+ def log_response(self, res: requests.Response, stream=False, verbose=False):
61
  status_code = res.status_code
62
  status_code_str = f"[{status_code}]"
63
 
 
65
  logger_func = logger.success
66
  else:
67
  logger_func = logger.warn
68
+
69
  logger_func(status_code_str)
70
 
71
+ if verbose:
72
+ if stream:
73
+ if not hasattr(self, "content_offset"):
74
+ self.content_offset = 0
75
+
76
+ for line in res.iter_lines():
77
+ line = line.decode("utf-8")
78
+ line = re.sub(r"^data:\s*", "", line)
79
+ if re.match(r"^\[DONE\]", line):
80
+ logger.success("\n[Finished]")
81
+ break
82
+ line = line.strip()
83
+ if line:
84
+ try:
85
+ data = json.loads(line, strict=False)
86
+ role = data["message"]["author"]["role"]
87
+ if role != "assistant":
88
+ continue
89
+ content = data["message"]["content"]["parts"][0]
90
+ delta_content = content[self.content_offset :]
91
+ self.content_offset = len(content)
92
+ logger_func(delta_content, end="")
93
+ except Exception as e:
94
+ logger.warn(e)
95
+ else:
96
+ logger_func(res.json())
97
 
98
  def get_models(self):
99
  self.log_request(self.api_models)
 
135
  "metadata": {},
136
  }
137
  ],
138
+ "parent_message_id": "",
139
  "model": "text-davinci-002-render-sha",
140
  "timezone_offset_min": -480,
141
  "suggestions": [],
 
148
  "websocket_request_id": str(uuid.uuid4()),
149
  }
150
  self.log_request(self.api_conversation, method="POST")
151
+ s = requests.Session()
152
+ res = s.post(
153
  self.api_conversation,
154
  headers=requests_headers,
155
  json=post_data,
156
  proxies=self.requests_proxies,
157
  timeout=10,
158
  impersonate="chrome120",
159
+ stream=True,
160
  )
161
+ self.log_response(res, stream=True, verbose=True)
162
 
163
 
164
  if __name__ == "__main__":
165
  api = OpenaiAPI()
166
  # api.get_models()
167
  api.auth()
168
+ prompt = "你的名字?"
169
  api.chat_completions(prompt)
170
 
171
  # python -m tests.openai