forked from phoenix/litellm-mirror
refactor(openai.py): moving openai chat completion calls to http
This commit is contained in:
parent
da1451e493
commit
c57ed0a9d7
6 changed files with 158 additions and 127 deletions
|
@ -7,7 +7,7 @@ from typing import Callable, Optional
|
|||
# This file just has the openai config classes.
|
||||
# For implementation check out completion() in main.py
|
||||
|
||||
class CustomOpenAIError(Exception):
|
||||
class OpenAIError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
|
@ -163,7 +163,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
def convert_to_model_response_object(self, response_object: Optional[dict]=None, model_response_object: Optional[ModelResponse]=None):
|
||||
try:
|
||||
if response_object is None or model_response_object is None:
|
||||
raise CustomOpenAIError(status_code=500, message="Error in response object format")
|
||||
raise OpenAIError(status_code=500, message="Error in response object format")
|
||||
choice_list=[]
|
||||
for idx, choice in enumerate(response_object["choices"]):
|
||||
message = Message(content=choice["message"]["content"], role=choice["message"]["role"])
|
||||
|
@ -181,7 +181,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
model_response_object.model = response_object["model"]
|
||||
return model_response_object
|
||||
except:
|
||||
CustomOpenAIError(status_code=500, message="Invalid response object.")
|
||||
OpenAIError(status_code=500, message="Invalid response object.")
|
||||
|
||||
def completion(self,
|
||||
model: Optional[str]=None,
|
||||
|
@ -193,58 +193,79 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
logging_obj=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None):
|
||||
logger_fn=None,
|
||||
headers: Optional[dict]=None):
|
||||
super().completion()
|
||||
headers = self.validate_environment(api_key=api_key)
|
||||
if model is None or messages is None:
|
||||
raise CustomOpenAIError(status_code=422, message=f"Missing model or messages")
|
||||
|
||||
for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params
|
||||
}
|
||||
try:
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
response = self._client_session.post(
|
||||
url=f"{api_base}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
stream=optional_params["stream"]
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise CustomOpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return response.iter_lines()
|
||||
else:
|
||||
response = self._client_session.post(
|
||||
url=f"{api_base}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise CustomOpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
|
||||
except Exception as e:
|
||||
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
|
||||
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
|
||||
new_messages = []
|
||||
for i in range(len(messages)-1):
|
||||
new_messages.append(messages[i])
|
||||
if messages[i]["role"] == messages[i+1]["role"]:
|
||||
if messages[i]["role"] == "user":
|
||||
new_messages.append({"role": "assistant", "content": ""})
|
||||
else:
|
||||
new_messages.append({"role": "user", "content": ""})
|
||||
new_messages.append(messages[-1])
|
||||
messages = new_messages
|
||||
elif "Last message must have role `user`" in str(e):
|
||||
new_messages = messages
|
||||
new_messages.append({"role": "user", "content": ""})
|
||||
messages = new_messages
|
||||
else:
|
||||
raise e
|
||||
exception_mapping_worked = False
|
||||
try:
|
||||
if headers is None:
|
||||
headers = self.validate_environment(api_key=api_key)
|
||||
if model is None or messages is None:
|
||||
raise OpenAIError(status_code=422, message=f"Missing model or messages")
|
||||
|
||||
for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base},
|
||||
)
|
||||
|
||||
try:
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
response = self._client_session.post(
|
||||
url=f"{api_base}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
stream=optional_params["stream"]
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return response.iter_lines()
|
||||
else:
|
||||
response = self._client_session.post(
|
||||
url=f"{api_base}/chat/completions",
|
||||
json=data,
|
||||
headers=headers,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise OpenAIError(status_code=response.status_code, message=response.text)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
|
||||
except Exception as e:
|
||||
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
|
||||
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
|
||||
new_messages = []
|
||||
for i in range(len(messages)-1):
|
||||
new_messages.append(messages[i])
|
||||
if messages[i]["role"] == messages[i+1]["role"]:
|
||||
if messages[i]["role"] == "user":
|
||||
new_messages.append({"role": "assistant", "content": ""})
|
||||
else:
|
||||
new_messages.append({"role": "user", "content": ""})
|
||||
new_messages.append(messages[-1])
|
||||
messages = new_messages
|
||||
elif "Last message must have role `user`" in str(e):
|
||||
new_messages = messages
|
||||
new_messages.append({"role": "user", "content": ""})
|
||||
messages = new_messages
|
||||
else:
|
||||
raise e
|
||||
except OpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
except Exception as e:
|
||||
if exception_mapping_worked:
|
||||
raise e
|
||||
else:
|
||||
import traceback
|
||||
raise OpenAIError(status_code=500, message=traceback.format_exc())
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue