refactor(openai.py): making it compatible for openai v1

BREAKING CHANGE:
Krrish Dholakia 2023-11-11 15:32:14 -08:00
parent 833c38edeb
commit d3323ba637
12 changed files with 622 additions and 370 deletions


@@ -375,6 +375,7 @@ from .integrations import *
 from .exceptions import (
     AuthenticationError,
     InvalidRequestError,
+    BadRequestError,
     RateLimitError,
     ServiceUnavailableError,
     OpenAIError,


@@ -8,75 +8,82 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan
 ## LiteLLM versions of the OpenAI Exception Types
-from openai.error import (
+from openai import (
     AuthenticationError,
-    InvalidRequestError,
+    BadRequestError,
     RateLimitError,
-    ServiceUnavailableError,
+    APIStatusError,
     OpenAIError,
     APIError,
-    Timeout,
+    APITimeoutError,
     APIConnectionError,
 )
+import httpx

 class AuthenticationError(AuthenticationError): # type: ignore
-    def __init__(self, message, llm_provider, model):
+    def __init__(self, message, llm_provider, model, response: httpx.Response):
         self.status_code = 401
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         super().__init__(
-            self.message
+            self.message,
+            response=response,
+            body=None
         ) # Call the base class constructor with the parameters it needs

-class InvalidRequestError(InvalidRequestError): # type: ignore
-    def __init__(self, message, model, llm_provider):
+class BadRequestError(BadRequestError): # type: ignore
+    def __init__(self, message, model, llm_provider, response: httpx.Response):
         self.status_code = 400
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
         super().__init__(
-            self.message, f"{self.model}"
+            self.message,
+            response=response,
+            body=None
         ) # Call the base class constructor with the parameters it needs

-class Timeout(Timeout): # type: ignore
-    def __init__(self, message, model, llm_provider):
+class Timeout(APITimeoutError): # type: ignore
+    def __init__(self, message, model, llm_provider, request: httpx.Request):
         self.status_code = 408
         self.message = message
         self.model = model
         self.llm_provider = llm_provider
         super().__init__(
-            self.message, f"{self.model}"
+            request=request
         ) # Call the base class constructor with the parameters it needs

-# sub class of invalid request error - meant to give more granularity for error handling context window exceeded errors
-class ContextWindowExceededError(InvalidRequestError): # type: ignore
-    def __init__(self, message, model, llm_provider):
-        self.status_code = 400
-        self.message = message
-        self.model = model
-        self.llm_provider = llm_provider
-        super().__init__(
-            self.message, self.model, self.llm_provider
-        ) # Call the base class constructor with the parameters it needs

 class RateLimitError(RateLimitError): # type: ignore
-    def __init__(self, message, llm_provider, model):
+    def __init__(self, message, llm_provider, model, response: httpx.Response):
         self.status_code = 429
         self.message = message
         self.llm_provider = llm_provider
         self.modle = model
         super().__init__(
-            self.message
+            self.message,
+            response=response,
+            body=None
         ) # Call the base class constructor with the parameters it needs

+# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
+class ContextWindowExceededError(BadRequestError): # type: ignore
+    def __init__(self, message, model, llm_provider, response: httpx.Response):
+        self.status_code = 400
+        self.message = message
+        self.model = model
+        self.llm_provider = llm_provider
+        super().__init__(
+            message=self.message,
+            model=self.model,
+            llm_provider=self.llm_provider,
+            response=response
+        ) # Call the base class constructor with the parameters it needs

-class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
+class ServiceUnavailableError(APIStatusError): # type: ignore
     def __init__(self, message, llm_provider, model):
-        self.status_code = 500
+        self.status_code = 503
         self.message = message
         self.llm_provider = llm_provider
         self.model = model

@@ -87,13 +94,14 @@ class ServiceUnavailableError(ServiceUnavailableError): # type: ignore

 # raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
 class APIError(APIError): # type: ignore
-    def __init__(self, status_code, message, llm_provider, model):
+    def __init__(self, status_code, message, llm_provider, model, request: httpx.Request):
         self.status_code = status_code
         self.message = message
         self.llm_provider = llm_provider
         self.model = model
         super().__init__(
-            self.message
+            self.message,
+            request=request
         )

 # raised if an invalid request (not get, delete, put, post) is made

@@ -123,4 +131,15 @@ class BudgetExceededError(Exception):
         self.current_cost = current_cost
         self.max_budget = max_budget
         message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
         super().__init__(message)
+
+## DEPRECATED ##
+class InvalidRequestError(BadRequestError): # type: ignore
+    def __init__(self, message, model, llm_provider):
+        self.status_code = 400
+        self.message = message
+        self.model = model
+        self.llm_provider = llm_provider
+        super().__init__(
+            self.message, f"{self.model}"
+        ) # Call the base class constructor with the parameters it needs
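For context on the new constructor signatures: under openai v1 these wrappers sit on top of httpx objects, so raising them requires a response (or request) to attach. A minimal, hypothetical sketch of how the remapped classes might be constructed; the synthetic httpx objects and error text below are illustrative and not part of the commit:

import httpx
from litellm.exceptions import BadRequestError, Timeout

# stand-in httpx objects, e.g. as an exception-mapping helper might build them
request = httpx.Request(method="POST", url="https://api.openai.com/v1/chat/completions")
response = httpx.Response(status_code=400, request=request)

try:
    raise BadRequestError(
        message="Invalid 'max_tokens': expected an integer",
        model="gpt-3.5-turbo",
        llm_provider="openai",
        response=response,  # required now that openai's BadRequestError wraps an httpx.Response
    )
except BadRequestError as e:
    print(e.status_code, e.llm_provider, e.model)  # 400 openai gpt-3.5-turbo

# timeouts map to openai.APITimeoutError, which carries the originating httpx.Request instead
timeout_err = Timeout(message="Request timed out", model="gpt-3.5-turbo", llm_provider="openai", request=request)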


@@ -1,16 +1,21 @@
 ## This is a template base class to be used for adding new LLM providers via API calls
 import litellm
-import requests, certifi, ssl
+import httpx, certifi, ssl

 class BaseLLM:
+    _client_session = None
+
     def create_client_session(self):
         if litellm.client_session:
-            session = litellm.client_session
+            _client_session = litellm.client_session
         else:
-            session = requests.Session()
-        return session
+            _client_session = httpx.Client(timeout=600)
+        return _client_session
+
+    def __exit__(self):
+        if hasattr(self, '_client_session'):
+            self._client_session.close()

     def validate_environment(self): # set up the environment required to run the model
         pass
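A rough usage sketch (not from the commit) of how a provider handler built on this base could reuse the shared httpx client across calls. The class name and health_check helper are hypothetical, and the import path assumes the base class lives at litellm/llms/base.py:

import httpx
from litellm.llms.base import BaseLLM

class ExampleProvider(BaseLLM):  # hypothetical handler, for illustration only
    _client_session: httpx.Client

    def __init__(self) -> None:
        super().__init__()
        # reuse litellm.client_session if the user configured one, else build an httpx.Client
        self._client_session = self.create_client_session()

    def health_check(self, url: str) -> int:
        # every call goes through the same client, so connections are pooled
        return self._client_session.get(url).status_code

provider = ExampleProvider()
print(provider.health_check("https://api.openai.com/v1/models"))  # likely 401 without an API key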


@@ -1,14 +1,17 @@
 from typing import Optional, Union
-import types, requests
+import types
+import httpx
 from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
 from typing import Callable, Optional
 import aiohttp

 class OpenAIError(Exception):
-    def __init__(self, status_code, message):
+    def __init__(self, status_code, message, request: httpx.Request, response: httpx.Response):
         self.status_code = status_code
         self.message = message
+        self.request = request
+        self.response = response
         super().__init__(
             self.message
         ) # Call the base class constructor with the parameters it needs

@@ -144,7 +147,7 @@ class OpenAITextCompletionConfig():
             and v is not None}

 class OpenAIChatCompletion(BaseLLM):
-    _client_session: requests.Session
+    _client_session: httpx.Client

     def __init__(self) -> None:
         super().__init__()
@@ -200,18 +203,8 @@ class OpenAIChatCompletion(BaseLLM):
                     return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
                 else:
                     return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
-            elif "stream" in optional_params and optional_params["stream"] == True:
-                response = self._client_session.post(
-                    url=api_base,
-                    json=data,
-                    headers=headers,
-                    stream=optional_params["stream"]
-                )
-                if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text)
-                ## RESPONSE OBJECT
-                return response.iter_lines()
+            elif optional_params.get("stream", False):
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
             else:
                 response = self._client_session.post(
                     url=api_base,

@@ -219,7 +212,7 @@ class OpenAIChatCompletion(BaseLLM):
                     headers=headers,
                 )
                 if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text)
+                    raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
                 ## RESPONSE OBJECT
                 return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
@@ -246,41 +239,64 @@ class OpenAIChatCompletion(BaseLLM):
                 exception_mapping_worked = True
                 raise e
             except Exception as e:
-                if exception_mapping_worked:
-                    raise e
-                else:
-                    import traceback
-                    raise OpenAIError(status_code=500, message=traceback.format_exc())
+                raise e

     async def acompletion(self,
                           api_base: str,
                           data: dict, headers: dict,
                           model_response: ModelResponse):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(api_base, json=data, headers=headers, ssl=False) as response:
-                response_json = await response.json()
+        async with httpx.AsyncClient() as client:
+            response = await client.post(api_base, json=data, headers=headers)
+            response_json = response.json()
             if response.status != 200:
                 raise OpenAIError(status_code=response.status, message=response.text)

             ## RESPONSE OBJECT
             return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)

+    def streaming(self,
+                  logging_obj,
+                  api_base: str,
+                  data: dict,
+                  headers: dict,
+                  model_response: ModelResponse,
+                  model: str
+    ):
+        with self._client_session.stream(
+                    url=f"{api_base}",
+                    json=data,
+                    headers=headers,
+                    method="POST"
+                ) as response:
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+
+            completion_stream = response.iter_lines()
+            streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai", logging_obj=logging_obj)
+            for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+
     async def async_streaming(self,
                           logging_obj,
                           api_base: str,
-                          data: dict, headers: dict,
+                          data: dict,
+                          headers: dict,
                           model_response: ModelResponse,
                           model: str):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(api_base, json=data, headers=headers, ssl=False) as response:
-                # Check if the request was successful (status code 200)
-                if response.status != 200:
-                    raise OpenAIError(status_code=response.status, message=await response.text())
-
-                streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai", logging_obj=logging_obj)
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+        client = httpx.AsyncClient()
+        async with client.stream(
+                    url=f"{api_base}",
+                    json=data,
+                    headers=headers,
+                    method="POST"
+                ) as response:
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+
+            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="openai", logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk

     def embedding(self,
                 model: str,
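For reference, the synchronous path above leans on httpx's streaming API; a standalone sketch of that pattern follows (the endpoint, payload, and key are placeholders, not taken from the commit). One subtlety: httpx exposes Response.text as a property rather than a method, and a streamed body has to be read before its text is available.

import httpx

client = httpx.Client(timeout=600)
payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "hi"}],
    "stream": True,
}

# the connection stays open inside the context manager and lines are decoded lazily,
# which is the shape CustomStreamWrapper expects to consume
with client.stream(
    "POST",
    "https://api.openai.com/v1/chat/completions",
    json=payload,
    headers={"Authorization": "Bearer sk-..."},
) as response:
    if response.status_code != 200:
        response.read()  # load the body so response.text is populated
        raise RuntimeError(response.text)
    for line in response.iter_lines():
        print(line)  # raw "data: {...}" server-sent-event lines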
@@ -349,7 +365,7 @@ class OpenAIChatCompletion(BaseLLM):

 class OpenAITextCompletion(BaseLLM):
-    _client_session: requests.Session
+    _client_session: httpx.Client

     def __init__(self) -> None:
         super().__init__()

@@ -367,7 +383,7 @@ class OpenAITextCompletion(BaseLLM):
         try:
             ## RESPONSE OBJECT
             if response_object is None or model_response_object is None:
-                raise OpenAIError(status_code=500, message="Error in response object format")
+                raise ValueError(message="Error in response object format")
             choice_list=[]
             for idx, choice in enumerate(response_object["choices"]):
                 message = Message(content=choice["text"], role="assistant")

@@ -386,8 +402,8 @@ class OpenAITextCompletion(BaseLLM):
             model_response_object._hidden_params["original_response"] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
             return model_response_object
-        except:
-            OpenAIError(status_code=500, message="Invalid response object.")
+        except Exception as e:
+            raise e

     def completion(self,
                 model: Optional[str]=None,

@@ -397,6 +413,7 @@ class OpenAITextCompletion(BaseLLM):
                 api_key: Optional[str]=None,
                 api_base: Optional[str]=None,
                 logging_obj=None,
+                acompletion: bool = False,
                 optional_params=None,
                 litellm_params=None,
                 logger_fn=None,

@@ -412,9 +429,6 @@ class OpenAITextCompletion(BaseLLM):
             api_base = f"{api_base}/completions"

             if len(messages)>0 and "content" in messages[0] and type(messages[0]["content"]) == list:
-                # Note: internal logic - for enabling litellm.text_completion()
-                # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
-                # https://platform.openai.com/docs/api-reference/completions/create
                 prompt = messages[0]["content"]
             else:
                 prompt = " ".join([message["content"] for message in messages]) # type: ignore
@@ -431,19 +445,13 @@ class OpenAITextCompletion(BaseLLM):
                 api_key=api_key,
                 additional_args={"headers": headers, "api_base": api_base, "data": data},
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
-                response = self._client_session.post(
-                    url=f"{api_base}",
-                    json=data,
-                    headers=headers,
-                    stream=optional_params["stream"]
-                )
-                if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text)
-                ## RESPONSE OBJECT
-                return response.iter_lines()
+            if acompletion == True:
+                if optional_params.get("stream", False):
+                    return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
+                else:
+                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model)
+            elif optional_params.get("stream", False):
+                return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
             else:
                 response = self._client_session.post(
                     url=f"{api_base}",

@@ -451,7 +459,7 @@ class OpenAITextCompletion(BaseLLM):
                     headers=headers,
                 )
                 if response.status_code != 200:
-                    raise OpenAIError(status_code=response.status_code, message=response.text)
+                    raise OpenAIError(status_code=response.status_code, message=response.text, request=self._client_session.request, response=response)

                 ## LOGGING
                 logging_obj.post_call(
@@ -466,12 +474,76 @@ class OpenAITextCompletion(BaseLLM):

                 ## RESPONSE OBJECT
                 return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
+        except OpenAIError as e:
+            exception_mapping_worked = True
+            raise e
         except Exception as e:
-            if exception_mapping_worked:
-                raise e
-            else:
-                import traceback
-                raise OpenAIError(status_code=500, message=traceback.format_exc())
+            raise e
+
+    async def acompletion(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          prompt: str,
+                          api_key: str,
+                          model: str):
+        async with httpx.AsyncClient() as client:
+            response = await client.post(api_base, json=data, headers=headers)
+            response_json = response.json()
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text)
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                original_response=response,
+                additional_args={
+                    "headers": headers,
+                    "api_base": api_base,
+                },
+            )
+
+            ## RESPONSE OBJECT
+            return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
+
+    def streaming(self,
+                  logging_obj,
+                  api_base: str,
+                  data: dict,
+                  headers: dict,
+                  model_response: ModelResponse,
+                  model: str
+    ):
+        with self._client_session.stream(
+                    url=f"{api_base}",
+                    json=data,
+                    headers=headers,
+                    method="POST"
+                ) as response:
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+
+            streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai", logging_obj=logging_obj)
+            for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+
+    async def async_streaming(self,
+                          logging_obj,
+                          api_base: str,
+                          data: dict,
+                          headers: dict,
+                          model_response: ModelResponse,
+                          model: str):
+        client = httpx.AsyncClient()
+        async with client.stream(
+                    url=f"{api_base}",
+                    json=data,
+                    headers=headers,
+                    method="POST"
+                ) as response:
+            if response.status_code != 200:
+                raise OpenAIError(status_code=response.status_code, message=response.text(), request=self._client_session.request, response=response)
+
+            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai", logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
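The new async text-completion path above reduces to a single httpx.AsyncClient POST. A standalone sketch of that call shape, with the URL, key, and payload as placeholders rather than values from the commit:

import asyncio
import httpx

async def call_completions(api_base: str, api_key: str, data: dict) -> dict:
    # one-shot async request; the client closes when the block exits
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{api_base}/completions",
            json=data,
            headers={"Authorization": f"Bearer {api_key}"},
        )
        if response.status_code != 200:
            raise RuntimeError(f"{response.status_code}: {response.text}")
        return response.json()

result = asyncio.run(call_completions(
    "https://api.openai.com/v1",
    "sk-...",
    {"model": "gpt-3.5-turbo-instruct", "prompt": "Say this is a test", "max_tokens": 10},
))
print(result["choices"][0]["text"])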


@@ -138,8 +138,10 @@ async def acompletion(*args, **kwargs):
     _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))

-    if (custom_llm_provider == "openai" or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+    if (custom_llm_provider == "openai"
+        or custom_llm_provider == "azure"
+        or custom_llm_provider == "custom_openai"
+        or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
         if kwargs.get("stream", False):
             response = completion(*args, **kwargs)
         else:

@@ -161,7 +163,7 @@ async def acompletion(*args, **kwargs):
                 line
                 async for line in response
             )
         else:
             end_time = datetime.datetime.now()
             # [OPTIONAL] ADD TO CACHE
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object

@@ -596,15 +598,16 @@ def completion(
                 print_verbose=print_verbose,
                 api_key=api_key,
                 api_base=api_base,
+                acompletion=acompletion,
                 logging_obj=logging,
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 logger_fn=logger_fn
             )
-            if "stream" in optional_params and optional_params["stream"] == True:
-                response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
-                return response
+            # if "stream" in optional_params and optional_params["stream"] == True:
+            #     response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
+            #     return response
             response = model_response
         elif (
             "replicate" in model or


@@ -28,15 +28,13 @@ def test_async_response():
     user_message = "Hello, how are you?"
     messages = [{"content": user_message, "role": "user"}]
     try:
-        response = await acompletion(model="gpt-3.5-turbo", messages=messages)
-        print(f"response: {response}")
-        response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+        response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages)
         print(f"response: {response}")
     except Exception as e:
         pytest.fail(f"An exception occurred: {e}")

 response = asyncio.run(test_get_response())
-# print(response)
+print(response)

 # test_async_response()

 def test_get_response_streaming():

@@ -45,8 +43,7 @@ def test_get_response_streaming():
     user_message = "write a short poem in one sentence"
     messages = [{"content": user_message, "role": "user"}]
     try:
-        response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
-        # response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+        response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages, stream=True)
         print(type(response))

         import inspect

@@ -59,18 +56,17 @@ def test_get_response_streaming():
         async for chunk in response:
             token = chunk["choices"][0]["delta"].get("content", "")
             output += token
+        print(f"output: {output}")
         assert output is not None, "output cannot be None."
         assert isinstance(output, str), "output needs to be of type str"
-        assert len(output) > 0, f"Length of output needs to be greater than 0. {output}"
-        print(f'output: {output}')
+        assert len(output) > 0, "Length of output needs to be greater than 0."
     except Exception as e:
         pytest.fail(f"An exception occurred: {e}")
     return response

 asyncio.run(test_async_call())

-test_get_response_streaming()
+# test_get_response_streaming()

 def test_get_response_non_openai_streaming():
     import asyncio


@@ -9,7 +9,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
 import pytest
-from openai.error import Timeout
+from openai import Timeout
 import litellm
 from litellm import embedding, completion, completion_cost
 from litellm import RateLimitError

@@ -405,7 +405,6 @@ def test_completion_openai():
         litellm.api_key = os.environ['OPENAI_API_KEY']
         response = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=10, request_timeout=10)
         print("This is the response object\n", response)
-        print("\n\nThis is response ms:", response.response_ms)

         response_str = response["choices"][0]["message"]["content"]

@@ -422,14 +421,15 @@ def test_completion_openai():
         pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_openai()
+test_completion_openai()

 def test_completion_text_openai():
     try:
-        litellm.set_verbose = True
+        # litellm.set_verbose = True
         response = completion(model="gpt-3.5-turbo-instruct", messages=messages)
-        print(response)
+        print(response["choices"][0]["message"]["content"])
     except Exception as e:
+        print(e)
         pytest.fail(f"Error occurred: {e}")
 # test_completion_text_openai()


@@ -14,7 +14,7 @@ import litellm
 from litellm import completion_with_retries, completion
 from litellm import (
     AuthenticationError,
-    InvalidRequestError,
+    BadRequestError,
     RateLimitError,
     ServiceUnavailableError,
     OpenAIError,


@@ -1,4 +1,8 @@
-from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
+try:
+    from openai import AuthenticationError, BadRequestError, RateLimitError, OpenAIError
+except:
+    from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
 import os
 import sys
 import traceback

@@ -38,23 +42,24 @@ models = ["command-nightly"]
 # Test 1: Context Window Errors
 @pytest.mark.parametrize("model", models)
 def test_context_window(model):
-    sample_text = "Say error 50 times" * 100000
+    sample_text = "Say error 50 times" * 10000
     messages = [{"content": sample_text, "role": "user"}]
+    print(f"model: {model}")
     try:
-        completion(model=model, messages=messages)
+        response = completion(model=model, messages=messages)
+        print(f"response: {response}")
+        print("FAILED!")
         pytest.fail(f"An exception occurred")
-    except ContextWindowExceededError:
-        pass
+    except ContextWindowExceededError as e:
+        print(f"Worked!")
     except RateLimitError:
-        pass
+        print("RateLimited!")
     except Exception as e:
         print(f"{e}")
         pytest.fail(f"An error occcurred - {e}")

 @pytest.mark.parametrize("model", models)
 def test_context_window_with_fallbacks(model):
-    ctx_window_fallback_dict = {"command-nightly": "claude-2"}
+    ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k"}
     sample_text = "how does a court case get to the Supreme Court?" * 1000
     messages = [{"content": sample_text, "role": "user"}]

@@ -62,8 +67,8 @@ def test_context_window_with_fallbacks(model):
 # for model in litellm.models_by_provider["bedrock"]:
 #     test_context_window(model=model)
-# test_context_window(model="gpt-3.5-turbo-instruct")
-# test_context_window_with_fallbacks(model="command-nightly")
+# test_context_window(model="gpt-3.5-turbo")
+# test_context_window_with_fallbacks(model="gpt-3.5-turbo")

 # Test 2: InvalidAuth Errors
 @pytest.mark.parametrize("model", models)
 def invalid_auth(model): # set the model key to an invalid key, depending on the model

@@ -158,14 +163,14 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
 # for model in litellm.models_by_provider["bedrock"]:
 #     invalid_auth(model=model)
-# invalid_auth(model="gpt-3.5-turbo-instruct")
+# invalid_auth(model="gpt-3.5-turbo")

 # Test 3: Invalid Request Error
 @pytest.mark.parametrize("model", models)
 def test_invalid_request_error(model):
     messages = [{"content": "hey, how's it going?", "role": "user"}]

-    with pytest.raises(InvalidRequestError):
+    with pytest.raises(BadRequestError):
         completion(model=model, messages=messages, max_tokens="hello world")

 # test_invalid_request_error(model="gpt-3.5-turbo")

@@ -178,15 +183,16 @@ def test_invalid_request_error(model):
 #         response = completion(model=model, messages=messages)
 #     except RateLimitError:
 #         return True
-#     except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
-#         return True
+#     # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
+#     #     return True
 #     except Exception as e:
 #         print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
 #         traceback.print_exc()
 #         pass
 #     return False

 # # Repeat each model 500 times
-# extended_models = [model for model in models for _ in range(250)]
+# # extended_models = [model for model in models for _ in range(250)]
+# extended_models = ["gpt-3.5-turbo-instruct" for _ in range(250)]

 # def worker(model):
 #     return test_model_call(model)


@@ -11,7 +11,7 @@ sys.path.insert(
 from dotenv import load_dotenv
 load_dotenv()
 import litellm
-from litellm import completion, acompletion, AuthenticationError, InvalidRequestError, RateLimitError
+from litellm import completion, acompletion, AuthenticationError, BadRequestError, RateLimitError, ModelResponse
 litellm.logging = False
 litellm.set_verbose = False

@@ -47,38 +47,17 @@ first_openai_chunk_example = {
 def validate_first_format(chunk):
     # write a test to make sure chunk follows the same format as first_openai_chunk_example
-    assert isinstance(chunk, dict), "Chunk should be a dictionary."
-    assert "id" in chunk, "Chunk should have an 'id'."
+    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
     assert isinstance(chunk['id'], str), "'id' should be a string."
-    assert "object" in chunk, "Chunk should have an 'object'."
     assert isinstance(chunk['object'], str), "'object' should be a string."
-    assert "created" in chunk, "Chunk should have a 'created'."
     assert isinstance(chunk['created'], int), "'created' should be an integer."
-    assert "model" in chunk, "Chunk should have a 'model'."
     assert isinstance(chunk['model'], str), "'model' should be a string."
-    assert "choices" in chunk, "Chunk should have 'choices'."
     assert isinstance(chunk['choices'], list), "'choices' should be a list."

     for choice in chunk['choices']:
-        assert isinstance(choice, dict), "Each choice should be a dictionary."
-        assert "index" in choice, "Each choice should have 'index'."
         assert isinstance(choice['index'], int), "'index' should be an integer."
-        assert "delta" in choice, "Each choice should have 'delta'."
-        assert isinstance(choice['delta'], dict), "'delta' should be a dictionary."
-        assert "role" in choice['delta'], "'delta' should have a 'role'."
         assert isinstance(choice['delta']['role'], str), "'role' should be a string."
-        assert "content" in choice['delta'], "'delta' should have 'content'."
         assert isinstance(choice['delta']['content'], str), "'content' should be a string."
-        assert "finish_reason" in choice, "Each choice should have 'finish_reason'."
         assert (choice['finish_reason'] is None) or isinstance(choice['finish_reason'], str), "'finish_reason' should be None or a string."

 second_openai_chunk_example = {
@@ -98,35 +77,16 @@ second_openai_chunk_example = {
 }

 def validate_second_format(chunk):
-    assert isinstance(chunk, dict), "Chunk should be a dictionary."
-    assert "id" in chunk, "Chunk should have an 'id'."
+    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
     assert isinstance(chunk['id'], str), "'id' should be a string."
-    assert "object" in chunk, "Chunk should have an 'object'."
     assert isinstance(chunk['object'], str), "'object' should be a string."
-    assert "created" in chunk, "Chunk should have a 'created'."
     assert isinstance(chunk['created'], int), "'created' should be an integer."
-    assert "model" in chunk, "Chunk should have a 'model'."
     assert isinstance(chunk['model'], str), "'model' should be a string."
-    assert "choices" in chunk, "Chunk should have 'choices'."
     assert isinstance(chunk['choices'], list), "'choices' should be a list."

     for choice in chunk['choices']:
-        assert isinstance(choice, dict), "Each choice should be a dictionary."
-        assert "index" in choice, "Each choice should have 'index'."
         assert isinstance(choice['index'], int), "'index' should be an integer."
-        assert "delta" in choice, "Each choice should have 'delta'."
-        assert isinstance(choice['delta'], dict), "'delta' should be a dictionary."
-        assert "content" in choice['delta'], "'delta' should have 'content'."
         assert isinstance(choice['delta']['content'], str), "'content' should be a string."
-        assert "finish_reason" in choice, "Each choice should have 'finish_reason'."
         assert (choice['finish_reason'] is None) or isinstance(choice['finish_reason'], str), "'finish_reason' should be None or a string."

 last_openai_chunk_example = {

@@ -144,32 +104,15 @@ last_openai_chunk_example = {
 }

 def validate_last_format(chunk):
-    assert isinstance(chunk, dict), "Chunk should be a dictionary."
-    assert "id" in chunk, "Chunk should have an 'id'."
+    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
     assert isinstance(chunk['id'], str), "'id' should be a string."
-    assert "object" in chunk, "Chunk should have an 'object'."
     assert isinstance(chunk['object'], str), "'object' should be a string."
-    assert "created" in chunk, "Chunk should have a 'created'."
     assert isinstance(chunk['created'], int), "'created' should be an integer."
-    assert "model" in chunk, "Chunk should have a 'model'."
     assert isinstance(chunk['model'], str), "'model' should be a string."
-    assert "choices" in chunk, "Chunk should have 'choices'."
     assert isinstance(chunk['choices'], list), "'choices' should be a list."

     for choice in chunk['choices']:
-        assert isinstance(choice, dict), "Each choice should be a dictionary."
-        assert "index" in choice, "Each choice should have 'index'."
         assert isinstance(choice['index'], int), "'index' should be an integer."
-        assert "delta" in choice, "Each choice should have 'delta'."
-        assert isinstance(choice['delta'], dict), "'delta' should be a dictionary."
-        assert "finish_reason" in choice, "Each choice should have 'finish_reason'."
         assert isinstance(choice['finish_reason'], str), "'finish_reason' should be a string."

 def streaming_format_tests(idx, chunk):
@@ -188,6 +131,7 @@ def streaming_format_tests(idx, chunk):
     if chunk["choices"][0]["finish_reason"]: # ensure finish reason is only in last chunk
         validate_last_format(chunk=chunk)
         finished = True
+    print(f"chunk choices: {chunk['choices'][0]['delta']['content']}")
     if "content" in chunk["choices"][0]["delta"]:
         extracted_chunk = chunk["choices"][0]["delta"]["content"]
         print(f"extracted chunk: {extracted_chunk}")

@@ -549,7 +493,7 @@ def test_completion_claude_stream_bad_key():
         pytest.fail(f"Error occurred: {e}")

-test_completion_claude_stream_bad_key()
+# test_completion_claude_stream_bad_key()
 # test_completion_replicate_stream()

 # def test_completion_vertexai_stream():

@@ -824,7 +768,7 @@ def ai21_completion_call_bad_key():
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
-    except InvalidRequestError as e:
+    except BadRequestError as e:
         pass
     except:
         pytest.fail(f"error occurred: {traceback.format_exc()}")

@@ -885,7 +829,7 @@ def ai21_completion_call_bad_key():
 # test on openai completion call
 def test_openai_chat_completion_call():
     try:
-        litellm.set_verbose = True
+        litellm.set_verbose = False
         response = completion(
             model="gpt-3.5-turbo", messages=messages, stream=True
         )

@@ -904,7 +848,7 @@ def test_openai_chat_completion_call():
         print(f"error occurred: {traceback.format_exc()}")
         pass

-# test_openai_chat_completion_call()
+test_openai_chat_completion_call()

 def test_openai_chat_completion_complete_response_call():
     try:

@@ -928,6 +872,7 @@ def test_openai_text_completion_call():
         start_time = time.time()
         for idx, chunk in enumerate(response):
             chunk, finished = streaming_format_tests(idx, chunk)
+            print(f"chunk: {chunk}")
             complete_response += chunk
             if finished:
                 break

@@ -939,6 +884,8 @@ def test_openai_text_completion_call():
         print(f"error occurred: {traceback.format_exc()}")
         pass

+# test_openai_text_completion_call()
+
 # # test on together ai completion call - starcoder
 def test_together_ai_completion_call_starcoder():
     try:

@@ -992,7 +939,7 @@ def test_together_ai_completion_call_starcoder_bad_key():
         if complete_response == "":
             raise Exception("Empty response received")
         print(f"complete response: {complete_response}")
-    except InvalidRequestError as e:
+    except BadRequestError as e:
         pass
     except:
         print(f"error occurred: {traceback.format_exc()}")


@@ -17,7 +17,10 @@ from concurrent import futures
 from inspect import iscoroutinefunction
 from functools import wraps
 from threading import Thread
-from openai.error import Timeout
+try:
+    from openai import Timeout
+except:
+    from openai.error import Timeout


 def timeout(timeout_duration: float = 0.0, exception_to_raise=Timeout):

File diff suppressed because it is too large.