Mirror of https://github.com/weyne85/PentestGPT.git, synced 2025-10-29 16:58:59 +00:00.
feat: 🎸 update together with testing
Update the chatgpt wrapper to keep alignment with RESTful testing.
Changed file: utils/chatgpt.py
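For orientation, a minimal driver sketch of how the reworked wrapper is meant to be used after this change, mirroring the __main__ block at the end of the file (prompt text, config values and the utils.chatgpt import path are illustrative assumptions):

    from config.chatgpt_config import ChatGPTConfig
    from utils.chatgpt import ChatGPT  # assumed import path; the class lives in utils/chatgpt.py

    # ChatGPTConfig supplies cf_clearance, session_token and related fields.
    chatgpt_config = ChatGPTConfig()
    chatgpt = ChatGPT(chatgpt_config)

    # Start a new conversation: returns the reply text and the new conversation id.
    text, conversation_id = chatgpt.send_new_message("Generate a RESTful test for the pet upload endpoint.")

    # Continue the same conversation: returns only the reply text.
    reply = chatgpt.send_message("Wrap the test in a fenced code block.", conversation_id)
    print(chatgpt.extract_code_fragments(reply))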
@@ -1,14 +1,16 @@
 # -*- coding: utf-8 -*-

 import dataclasses
 import json
 import re
 import time
 from typing import Any, Dict, List, Tuple
 from uuid import uuid1
+import datetime

 import loguru
 import requests


 from config.chatgpt_config import ChatGPTConfig

 logger = loguru.logger
@@ -26,6 +28,31 @@ logger = loguru.logger
 # is_debugging: bool = False


+@dataclasses.dataclass
+class Message:
+    ask_id: str = None
+    ask: dict = None
+    answer: dict = None
+    answer_id: str = None
+    request_start_timestamp: float = None
+    request_end_timestamp: float = None
+    time_escaped: float = None
+
+
+@dataclasses.dataclass
+class Conversation:
+    conversation_id: str = None
+    message_list: List[Message] = dataclasses.field(default_factory=list)
+
+    def __hash__(self):
+        return hash(self.conversation_id)
+
+    def __eq__(self, other):
+        if not isinstance(other, Conversation):
+            return False
+        return self.conversation_id == other.conversation_id
+
+
 class ChatGPT:
     def __init__(self, config: ChatGPTConfig):
         self.config = config
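The Message and Conversation dataclasses added above replace the flat latest_message_id_dict bookkeeping. A minimal sketch of how a single round trip is recorded, assuming the two classes exactly as defined in this hunk (ids and timestamps are placeholders):

    import time

    message = Message()
    message.ask_id = "question-uuid"                  # placeholder id
    message.request_start_timestamp = time.time()
    # ... the HTTP round trip to the backend happens here ...
    message.request_end_timestamp = time.time()
    message.time_escaped = message.request_end_timestamp - message.request_start_timestamp
    message.answer_id = "answer-uuid"                 # placeholder id

    # Conversation hashes and compares by conversation_id, so it can key a dict or sit in a set.
    conversation = Conversation()
    conversation.conversation_id = "conversation-uuid"
    conversation.message_list.append(message)
    conversation_dict = {conversation.conversation_id: conversation}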
@@ -35,7 +62,7 @@ class ChatGPT:
         self.cf_clearance = config.cf_clearance
         self.session_token = config.session_token
         # conversation_id: message_id
-        self.latest_message_id_dict = {}
+        self.conversation_dict: Dict[str, Conversation] = {}
         self.headers = dict(
             {
                 "cookie": f"cf_clearance={self.cf_clearance}; _puid={self._puid}; "
@@ -54,11 +81,26 @@ class ChatGPT:
         return "Bearer " + authorization

     def get_latest_message_id(self, conversation_id):
         # Get continuous conversation message id
         url = f"https://chat.openai.com/backend-api/conversation/{conversation_id}"
         r = requests.get(url, headers=self.headers, proxies=self.proxies)
         return r.json()["current_node"]

+    def _parse_message_raw_output(self, response: requests.Response):
+        # parse message raw output
+        last_line = None
+        for line in response.iter_lines():
+            if line:
+                decoded_line = line.decode("utf-8")
+                if len(decoded_line) == 12:
+                    break
+                if "data:" in decoded_line:
+                    last_line = decoded_line
+        result = json.loads(last_line[5:])
+        return result
+
     def send_new_message(self, message):
         # Send a message in a new conversation window and return the conversation id
         logger.info(f"send_new_message")
         url = "https://chat.openai.com/backend-api/conversation"
         message_id = str(uuid1())
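The new _parse_message_raw_output helper centralises the streaming-response parsing that send_new_message and send_message previously did inline. A stand-alone sketch of the same logic over a faked event stream (the byte strings are illustrative; the length-12 check matches a closing "data: [DONE]" line, which is exactly 12 characters):

    import json

    fake_stream = [  # stand-in for response.iter_lines()
        b'data: {"message": {"id": "m1", "content": {"parts": ["Hel"]}}, "conversation_id": "c1"}',
        b"",
        b'data: {"message": {"id": "m1", "content": {"parts": ["Hello"]}}, "conversation_id": "c1"}',
        b"data: [DONE]",
    ]

    last_line = None
    for line in fake_stream:
        if line:
            decoded_line = line.decode("utf-8")
            if len(decoded_line) == 12:        # the terminating "data: [DONE]" line
                break
            if "data:" in decoded_line:
                last_line = decoded_line

    result = json.loads(last_line[5:])         # drop the "data:" prefix, keep the JSON payload
    print("\n".join(result["message"]["content"]["parts"]))   # prints: Hello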
@@ -74,44 +116,58 @@ class ChatGPT:
             "parent_message_id": str(uuid1()),
             "model": self.model,
         }

-        r = requests.post(url, headers=self.headers, json=data, proxies=self.proxies)
+        start_time = time.time()
+        message: Message = Message()
+        message.ask_id = message_id
+        message.ask = data
+        message.request_start_timestamp = start_time
+        r = requests.post(
+            url, headers=self.headers, json=data, proxies=self.proxies, stream=True
+        )
         if r.status_code != 200:
-            logger.error(r.json()["detail"])
-            time.sleep(self.config.error_wait_time)
-            return self.send_new_message(message)
-            # wait for 20s
+            logger.error(r.text)
+            return None, None

-        last_line = None
-
-        for line in r.iter_lines():
-            if line:
-                decoded_line = line.decode("utf-8")
-                if len(decoded_line) == 12:
-                    break
-                last_line = decoded_line
-
-        result = json.loads(last_line[5:])
+        # parsing result
+        result = self._parse_message_raw_output(r)
         text = "\n".join(result["message"]["content"]["parts"])
         conversation_id = result["conversation_id"]
-        response_message_id = result["message"]["id"]
-        self.latest_message_id_dict[conversation_id] = response_message_id
+        answer_id = result["message"]["id"]

+        end_time = time.time()
+        message.answer_id = answer_id
+        message.answer = result
+        message.request_end_timestamp = end_time
+        message.time_escaped = end_time - start_time
+        conversation: Conversation = Conversation()
+        conversation.conversation_id = conversation_id
+        conversation.message_list.append(message)

+        self.conversation_dict[conversation_id] = conversation
         return text, conversation_id

     def send_message(self, message, conversation_id):
         # Send message to specific conversation
         logger.info(f"send_message")
         url = "https://chat.openai.com/backend-api/conversation"
-        if conversation_id not in self.latest_message_id_dict:

+        # get message from id
+        if conversation_id not in self.conversation_dict:
             logger.info(f"conversation_id: {conversation_id}")
             message_id = self.get_latest_message_id(conversation_id)
             logger.info(f"message_id: {message_id}")
         else:
-            message_id = self.latest_message_id_dict[conversation_id]
+            message_id = (
+                self.conversation_dict[conversation_id].message_list[-1].answer_id
+            )

+        new_message_id = str(uuid1())
         data = {
             "action": "next",
             "messages": [
                 {
-                    "id": str(uuid1()),
+                    "id": new_message_id,
                     "role": "user",
                     "content": {"content_type": "text", "parts": [message]},
                 }
@@ -120,20 +176,35 @@ class ChatGPT:
             "parent_message_id": message_id,
             "model": self.model,
         }
-        r = requests.post(url, headers=self.headers, json=data, proxies=self.proxies)
-        if r.status_code != 200:
-            logger.warning(r.json()["detail"])
-            time.sleep(self.config.error_wait_time)
-            return self.send_message(message, conversation_id)
-        response_result = json.loads(r.text.split("data: ")[-2])
-        new_message_id = response_result["message"]["id"]
-        self.latest_message_id_dict[conversation_id] = new_message_id
-        text = "\n".join(response_result["message"]["content"]["parts"])
-        return text

-    def extract_code_fragments(self, text):
-        code_fragments = re.findall(r"```(.*?)```", text, re.DOTALL)
-        return code_fragments
+        start_time = time.time()
+        message: Message = Message()
+        message.ask_id = new_message_id
+        message.ask = data
+        message.request_start_timestamp = start_time

+        r = requests.post(
+            url, headers=self.headers, json=data, proxies=self.proxies, stream=True
+        )
+        if r.status_code != 200:
+            # When sending is blocked, wait 20 seconds and resend
+            logger.warning(f"chatgpt failed: {r.text}")
+            return None, None
+        # parsing result

+        result = self._parse_message_raw_output(r)
+        text = "\n".join(result["message"]["content"]["parts"])
+        conversation_id = result["conversation_id"]
+        answer_id = result["message"]["id"]

+        end_time = time.time()
+        message.answer_id = answer_id
+        message.answer = result
+        message.request_end_timestamp = end_time
+        message.time_escaped = end_time - start_time
+        conversation: Conversation = self.conversation_dict[conversation_id]
+        conversation.message_list.append(message)
+        return text

     def get_conversation_history(self, limit=20, offset=0):
         # Get the conversation id in the history
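For clarity, the parent_message_id used by the reworked send_message now comes from the local Conversation record when one exists, and from the backend otherwise. Expressed as a hypothetical helper (the real code above inlines this logic):

    def resolve_parent_message_id(chatgpt: "ChatGPT", conversation_id: str) -> str:
        # Locally known conversation: chain on the answer_id of the last recorded Message.
        if conversation_id in chatgpt.conversation_dict:
            return chatgpt.conversation_dict[conversation_id].message_list[-1].answer_id
        # Unknown conversation: ask the backend for its current node.
        return chatgpt.get_latest_message_id(conversation_id)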
@@ -166,8 +237,8 @@ class ChatGPT:
         r = requests.patch(url, headers=self.headers, json=data, proxies=self.proxies)

         # delete conversation id locally
-        if conversation_id in self.latest_message_id_dict:
-            del self.latest_message_id_dict[conversation_id]
+        if conversation_id in self.conversation_dict:
+            del self.conversation_dict[conversation_id]

         if r.status_code == 200:
             return True
@@ -175,6 +246,9 @@ class ChatGPT:
             logger.error("Failed to delete conversation")
             return False

+    def extract_code_fragments(self, text):
+        code_fragments = re.findall(r"```(.*?)```", text, re.DOTALL)
+        return code_fragments

 if __name__ == "__main__":
     chatgpt_config = ChatGPTConfig()
@@ -186,4 +260,4 @@ if __name__ == "__main__":
         "generate: {'post': {'tags': ['pet'], 'summary': 'uploads an image', 'description': '', 'operationId': 'uploadFile', 'consumes': ['multipart/form-data'], 'produces': ['application/json'], 'parameters': [{'name': 'petId', 'in': 'path', 'description': 'ID of pet to update', 'required': True, 'type': 'integer', 'format': 'int64'}, {'name': 'additionalMetadata', 'in': 'formData', 'description': 'Additional data to pass to server', 'required': False, 'type': 'string'}, {'name': 'file', 'in': 'formData', 'description': 'file to upload', 'required': False, 'type': 'file'}], 'responses': {'200': {'description': 'successful operation', 'schema': {'type': 'object', 'properties': {'code': {'type': 'integer', 'format': 'int32'}, 'type': {'type': 'string'}, 'message': {'type': 'string'}}}}}, 'security': [{'petstore_auth': ['write:pets', 'read:pets']}]}}",
         conversation_id,
     )
-    logger.info(chatgpt.extract_code_fragments(result))
+    logger.info(chatgpt.extract_code_fragments(result))
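For reference, what extract_code_fragments returns for a typical fenced reply (illustrative string; re.DOTALL lets the match span newlines, and a leading language tag such as python stays inside the captured fragment):

    import re

    reply = "Here is the test:\n```python\nassert 1 + 1 == 2\n```\nDone."
    fragments = re.findall(r"```(.*?)```", reply, re.DOTALL)
    print(fragments)    # ['python\nassert 1 + 1 == 2\n']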