From 3f6b6d84c4ff3c65d76316c52359f73dade97505 Mon Sep 17 00:00:00 2001 From: Grey_D Date: Sat, 25 Mar 2023 00:12:18 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=F0=9F=8E=B8=20update=20together=20with?= =?UTF-8?q?=20testing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit update chatgpt wrapper to keep allignment with RESTful testing --- utils/chatgpt.py | 150 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 112 insertions(+), 38 deletions(-) diff --git a/utils/chatgpt.py b/utils/chatgpt.py index 96658e2..07732c0 100644 --- a/utils/chatgpt.py +++ b/utils/chatgpt.py @@ -1,14 +1,16 @@ # -*- coding: utf-8 -*- +import dataclasses import json import re import time +from typing import Any, Dict, List, Tuple from uuid import uuid1 -import datetime import loguru import requests + from config.chatgpt_config import ChatGPTConfig logger = loguru.logger @@ -26,6 +28,31 @@ logger = loguru.logger # is_debugging: bool = False +@dataclasses.dataclass +class Message: + ask_id: str = None + ask: dict = None + answer: dict = None + answer_id: str = None + request_start_timestamp: float = None + request_end_timestamp: float = None + time_escaped: float = None + + +@dataclasses.dataclass +class Conversation: + conversation_id: str = None + message_list: List[Message] = dataclasses.field(default_factory=list) + + def __hash__(self): + return hash(self.conversation_id) + + def __eq__(self, other): + if not isinstance(other, Conversation): + return False + return self.conversation_id == other.conversation_id + + class ChatGPT: def __init__(self, config: ChatGPTConfig): self.config = config @@ -35,7 +62,7 @@ class ChatGPT: self.cf_clearance = config.cf_clearance self.session_token = config.session_token # conversation_id: message_id - self.latest_message_id_dict = {} + self.conversation_dict: Dict[str, Conversation] = {} self.headers = dict( { "cookie": f"cf_clearance={self.cf_clearance}; _puid={self._puid}; " @@ -54,11 +81,26 @@ class ChatGPT: return "Bearer " + authorization def get_latest_message_id(self, conversation_id): + # Get continuous conversation message id url = f"https://chat.openai.com/backend-api/conversation/{conversation_id}" r = requests.get(url, headers=self.headers, proxies=self.proxies) return r.json()["current_node"] + def _parse_message_raw_output(self, response: requests.Response): + # parse message raw output + last_line = None + for line in response.iter_lines(): + if line: + decoded_line = line.decode("utf-8") + if len(decoded_line) == 12: + break + if "data:" in decoded_line: + last_line = decoded_line + result = json.loads(last_line[5:]) + return result + def send_new_message(self, message): + # 发送新会话窗口消息,返回会话id logger.info(f"send_new_message") url = "https://chat.openai.com/backend-api/conversation" message_id = str(uuid1()) @@ -74,44 +116,58 @@ class ChatGPT: "parent_message_id": str(uuid1()), "model": self.model, } - - r = requests.post(url, headers=self.headers, json=data, proxies=self.proxies) + start_time = time.time() + message: Message = Message() + message.ask_id = message_id + message.ask = data + message.request_start_timestamp = start_time + r = requests.post( + url, headers=self.headers, json=data, proxies=self.proxies, stream=True + ) if r.status_code != 200: - logger.error(r.json()["detail"]) - time.sleep(self.config.error_wait_time) - return self.send_new_message(message) + # wait for 20s + logger.error(r.text) + return None, None - last_line = None - - for line in r.iter_lines(): - if line: - decoded_line = line.decode("utf-8") - if len(decoded_line) == 12: - break - last_line = decoded_line - - result = json.loads(last_line[5:]) + # parsing result + result = self._parse_message_raw_output(r) text = "\n".join(result["message"]["content"]["parts"]) conversation_id = result["conversation_id"] - response_message_id = result["message"]["id"] - self.latest_message_id_dict[conversation_id] = response_message_id + answer_id = result["message"]["id"] + + end_time = time.time() + message.answer_id = answer_id + message.answer = result + message.request_end_timestamp = end_time + message.time_escaped = end_time - start_time + conversation: Conversation = Conversation() + conversation.conversation_id = conversation_id + conversation.message_list.append(message) + + self.conversation_dict[conversation_id] = conversation return text, conversation_id def send_message(self, message, conversation_id): + # Send message to specific conversation logger.info(f"send_message") url = "https://chat.openai.com/backend-api/conversation" - if conversation_id not in self.latest_message_id_dict: + + # get message from id + if conversation_id not in self.conversation_dict: logger.info(f"conversation_id: {conversation_id}") message_id = self.get_latest_message_id(conversation_id) logger.info(f"message_id: {message_id}") else: - message_id = self.latest_message_id_dict[conversation_id] + message_id = ( + self.conversation_dict[conversation_id].message_list[-1].answer_id + ) + new_message_id = str(uuid1()) data = { "action": "next", "messages": [ { - "id": str(uuid1()), + "id": new_message_id, "role": "user", "content": {"content_type": "text", "parts": [message]}, } @@ -120,20 +176,35 @@ class ChatGPT: "parent_message_id": message_id, "model": self.model, } - r = requests.post(url, headers=self.headers, json=data, proxies=self.proxies) - if r.status_code != 200: - logger.warning(r.json()["detail"]) - time.sleep(self.config.error_wait_time) - return self.send_message(message, conversation_id) - response_result = json.loads(r.text.split("data: ")[-2]) - new_message_id = response_result["message"]["id"] - self.latest_message_id_dict[conversation_id] = new_message_id - text = "\n".join(response_result["message"]["content"]["parts"]) - return text - def extract_code_fragments(self, text): - code_fragments = re.findall(r"```(.*?)```", text, re.DOTALL) - return code_fragments + start_time = time.time() + message: Message = Message() + message.ask_id = new_message_id + message.ask = data + message.request_start_timestamp = start_time + + r = requests.post( + url, headers=self.headers, json=data, proxies=self.proxies, stream=True + ) + if r.status_code != 200: + # 发送消息阻塞时等待20秒从新发送 + logger.warning(f"chatgpt failed: {r.text}") + return None, None + # parsing result + + result = self._parse_message_raw_output(r) + text = "\n".join(result["message"]["content"]["parts"]) + conversation_id = result["conversation_id"] + answer_id = result["message"]["id"] + + end_time = time.time() + message.answer_id = answer_id + message.answer = result + message.request_end_timestamp = end_time + message.time_escaped = end_time - start_time + conversation: Conversation = self.conversation_dict[conversation_id] + conversation.message_list.append(message) + return text def get_conversation_history(self, limit=20, offset=0): # Get the conversation id in the history @@ -166,8 +237,8 @@ class ChatGPT: r = requests.patch(url, headers=self.headers, json=data, proxies=self.proxies) # delete conversation id locally - if conversation_id in self.latest_message_id_dict: - del self.latest_message_id_dict[conversation_id] + if conversation_id in self.conversation_dict: + del self.conversation_dict[conversation_id] if r.status_code == 200: return True @@ -175,6 +246,9 @@ class ChatGPT: logger.error("Failed to delete conversation") return False + def extract_code_fragments(self, text): + code_fragments = re.findall(r"```(.*?)```", text, re.DOTALL) + return code_fragments if __name__ == "__main__": chatgpt_config = ChatGPTConfig() @@ -186,4 +260,4 @@ if __name__ == "__main__": "generate: {'post': {'tags': ['pet'], 'summary': 'uploads an image', 'description': '', 'operationId': 'uploadFile', 'consumes': ['multipart/form-data'], 'produces': ['application/json'], 'parameters': [{'name': 'petId', 'in': 'path', 'description': 'ID of pet to update', 'required': True, 'type': 'integer', 'format': 'int64'}, {'name': 'additionalMetadata', 'in': 'formData', 'description': 'Additional data to pass to server', 'required': False, 'type': 'string'}, {'name': 'file', 'in': 'formData', 'description': 'file to upload', 'required': False, 'type': 'file'}], 'responses': {'200': {'description': 'successful operation', 'schema': {'type': 'object', 'properties': {'code': {'type': 'integer', 'format': 'int32'}, 'type': {'type': 'string'}, 'message': {'type': 'string'}}}}}, 'security': [{'petstore_auth': ['write:pets', 'read:pets']}]}}", conversation_id, ) - logger.info(chatgpt.extract_code_fragments(result)) + logger.info(chatgpt.extract_code_fragments(result)) \ No newline at end of file