refactor(openai.py): move OpenAI chat completion calls to direct HTTP requests

This commit is contained in:
Krrish Dholakia 2023-11-08 17:40:32 -08:00
parent da1451e493
commit c57ed0a9d7
6 changed files with 158 additions and 127 deletions

View file

@ -7,7 +7,7 @@ from typing import Callable, Optional
# This file only contains the OpenAI config classes.
# For the implementation, see completion() in main.py.
class CustomOpenAIError(Exception):
class OpenAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
@ -163,7 +163,7 @@ class OpenAIChatCompletion(BaseLLM):
def convert_to_model_response_object(self, response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
try:
if response_object is None or model_response_object is None:
raise CustomOpenAIError(status_code=500, message="Error in response object format")
raise OpenAIError(status_code=500, message="Error in response object format")
choice_list=[]
for idx, choice in enumerate(response_object["choices"]):
message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
@ -181,7 +181,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response_object.model = response_object["model"]
return model_response_object
except:
CustomOpenAIError(status_code=500, message="Invalid response object.")
OpenAIError(status_code=500, message="Invalid response object.")
def completion(self,
model: Optional[str]=None,
@ -193,58 +193,79 @@ class OpenAIChatCompletion(BaseLLM):
logging_obj=None,
optional_params=None,
litellm_params=None,
logger_fn=None):
logger_fn=None,
headers: Optional[dict]=None):
super().completion()
headers = self.validate_environment(api_key=api_key)
if model is None or messages is None:
raise CustomOpenAIError(status_code=422, message=f"Missing model or messages")
for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
data = {
"model": model,
"messages": messages,
**optional_params
}
try:
if "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post(
url=f"{api_base}/chat/completions",
json=data,
headers=headers,
stream=optional_params["stream"]
)
if response.status_code != 200:
raise CustomOpenAIError(status_code=response.status_code, message=response.text)
## RESPONSE OBJECT
return response.iter_lines()
else:
response = self._client_session.post(
url=f"{api_base}/chat/completions",
json=data,
headers=headers,
)
if response.status_code != 200:
raise CustomOpenAIError(status_code=response.status_code, message=response.text)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
# Reformat messages so that user/assistant roles alternate: if there are two consecutive 'user' messages or two consecutive 'assistant' messages, insert a blank message of the opposite role between them to ensure compatibility.
new_messages = []
for i in range(len(messages)-1):
new_messages.append(messages[i])
if messages[i]["role"] == messages[i+1]["role"]:
if messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""})
else:
new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1])
messages = new_messages
elif "Last message must have role `user`" in str(e):
new_messages = messages
new_messages.append({"role": "user", "content": ""})
messages = new_messages
else:
raise e
exception_mapping_worked = False
try:
if headers is None:
headers = self.validate_environment(api_key=api_key)
if model is None or messages is None:
raise OpenAIError(status_code=422, message=f"Missing model or messages")
for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
data = {
"model": model,
"messages": messages,
**optional_params
}
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base},
)
try:
if "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post(
url=f"{api_base}/chat/completions",
json=data,
headers=headers,
stream=optional_params["stream"]
)
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## RESPONSE OBJECT
return response.iter_lines()
else:
response = self._client_session.post(
url=f"{api_base}/chat/completions",
json=data,
headers=headers,
)
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
# Reformat messages so that user/assistant roles alternate: if there are two consecutive 'user' messages or two consecutive 'assistant' messages, insert a blank message of the opposite role between them to ensure compatibility.
new_messages = []
for i in range(len(messages)-1):
new_messages.append(messages[i])
if messages[i]["role"] == messages[i+1]["role"]:
if messages[i]["role"] == "user":
new_messages.append({"role": "assistant", "content": ""})
else:
new_messages.append({"role": "user", "content": ""})
new_messages.append(messages[-1])
messages = new_messages
elif "Last message must have role `user`" in str(e):
new_messages = messages
new_messages.append({"role": "user", "content": ""})
messages = new_messages
else:
raise e
except OpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())