refactor(openai.py): making it compatible with openai v1

BREAKING CHANGE:
Krrish Dholakia 2023-11-11 15:32:14 -08:00
parent 833c38edeb
commit d3323ba637
12 changed files with 622 additions and 370 deletions

@ -375,6 +375,7 @@ from .integrations import *
from .exceptions import (
AuthenticationError,
InvalidRequestError,
BadRequestError,
RateLimitError,
ServiceUnavailableError,
OpenAIError,

@ -8,75 +8,82 @@
# Thank you users! We ❤️ you! - Krrish & Ishaan
## LiteLLM versions of the OpenAI Exception Types
from openai.error import (
from openai import (
AuthenticationError,
InvalidRequestError,
BadRequestError,
RateLimitError,
ServiceUnavailableError,
APIStatusError,
OpenAIError,
APIError,
Timeout,
APITimeoutError,
APIConnectionError,
)
import httpx
class AuthenticationError(AuthenticationError): # type: ignore
def __init__(self, message, llm_provider, model):
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 401
self.message = message
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message
self.message,
response=response,
body=None
) # Call the base class constructor with the parameters it needs
class InvalidRequestError(InvalidRequestError): # type: ignore
def __init__(self, message, model, llm_provider):
class BadRequestError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message, f"{self.model}"
self.message,
response=response,
body=None
) # Call the base class constructor with the parameters it needs
class Timeout(Timeout): # type: ignore
def __init__(self, message, model, llm_provider):
class Timeout(APITimeoutError): # type: ignore
def __init__(self, message, model, llm_provider, request: httpx.Request):
self.status_code = 408
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message, f"{self.model}"
request=request
) # Call the base class constructor with the parameters it needs
# sub class of invalid request error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(InvalidRequestError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message, self.model, self.llm_provider
) # Call the base class constructor with the parameters it needs
class RateLimitError(RateLimitError): # type: ignore
def __init__(self, message, llm_provider, model):
def __init__(self, message, llm_provider, model, response: httpx.Response):
self.status_code = 429
self.message = message
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message
self.message,
response=response,
body=None
) # Call the base class constructor with the parameters it needs
# sub class of rate limit error - meant to give more granularity for error handling context window exceeded errors
class ContextWindowExceededError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider, response: httpx.Response):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
message=self.message,
model=self.model,
llm_provider=self.llm_provider,
response=response
) # Call the base class constructor with the parameters it needs
class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
class ServiceUnavailableError(APIStatusError): # type: ignore
def __init__(self, message, llm_provider, model):
self.status_code = 500
self.status_code = 503
self.message = message
self.llm_provider = llm_provider
self.model = model
@ -87,13 +94,14 @@ class ServiceUnavailableError(ServiceUnavailableError): # type: ignore
# raise this when the API returns an invalid response object - https://github.com/openai/openai-python/blob/1be14ee34a0f8e42d3f9aa5451aa4cb161f1781f/openai/api_requestor.py#L401
class APIError(APIError): # type: ignore
def __init__(self, status_code, message, llm_provider, model):
def __init__(self, status_code, message, llm_provider, model, request: httpx.Request):
self.status_code = status_code
self.message = message
self.llm_provider = llm_provider
self.model = model
super().__init__(
self.message
self.message,
request=request
)
# raised if an invalid request (not get, delete, put, post) is made
@ -123,4 +131,15 @@ class BudgetExceededError(Exception):
self.current_cost = current_cost
self.max_budget = max_budget
message = f"Budget has been exceeded! Current cost: {current_cost}, Max budget: {max_budget}"
super().__init__(message)
super().__init__(message)
## DEPRECATED ##
class InvalidRequestError(BadRequestError): # type: ignore
def __init__(self, message, model, llm_provider):
self.status_code = 400
self.message = message
self.model = model
self.llm_provider = llm_provider
super().__init__(
self.message, f"{self.model}"
) # Call the base class constructor with the parameters it needs
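
Because the openai v1 exception classes are built on httpx, every mapped LiteLLM exception above now expects an httpx.Response (or an httpx.Request for timeouts) at construction time. A minimal sketch of the new contract, assuming this refactored exceptions module and openai>=1.0; the endpoint URL and message string are placeholders:

import httpx
from litellm import AuthenticationError  # re-exported via the __init__.py hunk above

# Build the httpx objects the openai v1 base classes require (placeholder URL).
request = httpx.Request(method="POST", url="https://api.openai.com/v1/chat/completions")
response = httpx.Response(status_code=401, request=request)

try:
    raise AuthenticationError(
        message="Invalid API key",  # placeholder message
        llm_provider="openai",
        model="gpt-3.5-turbo",
        response=response,
    )
except AuthenticationError as e:
    # The wrapper keeps LiteLLM's metadata on top of the openai v1 base class.
    print(e.status_code, e.llm_provider, e.model)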

@ -1,16 +1,21 @@
## This is a template base class to be used for adding new LLM providers via API calls
import litellm
import requests, certifi, ssl
import httpx, certifi, ssl
class BaseLLM:
_client_session = None
def create_client_session(self):
if litellm.client_session:
session = litellm.client_session
_client_session = litellm.client_session
else:
session = requests.Session()
return session
_client_session = httpx.Client(timeout=600)
return _client_session
def __exit__(self):
if hasattr(self, '_client_session'):
self._client_session.close()
def validate_environment(self): # set up the environment required to run the model
pass
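
Since create_client_session now returns an httpx.Client, callers can share one connection pool across providers by setting litellm.client_session before making calls; otherwise the base class falls back to httpx.Client(timeout=600). A minimal sketch, assuming the refactored BaseLLM above:

import httpx
import litellm

# Reuse a single httpx client (and its connection pool) for LiteLLM's HTTP calls.
# If left unset, BaseLLM.create_client_session() builds httpx.Client(timeout=600).
litellm.client_session = httpx.Client(timeout=600)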

@ -1,14 +1,17 @@
from typing import Optional, Union
import types, requests
import types
import httpx
from .base import BaseLLM
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
from typing import Callable, Optional
import aiohttp
class OpenAIError(Exception):
def __init__(self, status_code, message):
def __init__(self, status_code, message, request: httpx.Request, response: httpx.Response):
self.status_code = status_code
self.message = message
self.request = request
self.response = response
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
@ -144,7 +147,7 @@ class OpenAITextCompletionConfig():
and v is not None}
class OpenAIChatCompletion(BaseLLM):
_client_session: requests.Session
_client_session: httpx.Client
def __init__(self) -> None:
super().__init__()
@ -200,18 +203,8 @@ class OpenAIChatCompletion(BaseLLM):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response)
elif "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post(
url=api_base,
json=data,
headers=headers,
stream=optional_params["stream"]
)
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## RESPONSE OBJECT
return response.iter_lines()
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
else:
response = self._client_session.post(
url=api_base,
@ -219,7 +212,7 @@ class OpenAIChatCompletion(BaseLLM):
headers=headers,
)
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT
return convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
@ -246,41 +239,64 @@ class OpenAIChatCompletion(BaseLLM):
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
raise e
async def acompletion(self,
api_base: str,
data: dict, headers: dict,
model_response: ModelResponse):
async with aiohttp.ClientSession() as session:
async with session.post(api_base, json=data, headers=headers, ssl=False) as response:
response_json = await response.json()
if response.status != 200:
raise OpenAIError(status_code=response.status, message=response.text)
async with httpx.AsyncClient() as client:
response = await client.post(api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT
return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
## RESPONSE OBJECT
return convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
def streaming(self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str
):
with self._client_session.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST"
) as response:
if response.status_code != 200:
response.read() # read the streamed body so .text is available before raising
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
completion_stream = response.iter_lines()
streamwrapper = CustomStreamWrapper(completion_stream=completion_stream, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
yield transformed_chunk
async def async_streaming(self,
logging_obj,
api_base: str,
data: dict, headers: dict,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str):
async with aiohttp.ClientSession() as session:
async with session.post(api_base, json=data, headers=headers, ssl=False) as response:
# Check if the request was successful (status code 200)
if response.status != 200:
raise OpenAIError(status_code=response.status, message=await response.text())
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
client = httpx.AsyncClient()
async with client.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST"
) as response:
if response.status_code != 200:
await response.aread() # read the streamed body so .text is available before raising
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
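
The streaming and async_streaming methods above replace the old requests/aiohttp calls with httpx's stream() context manager and hand the resulting line iterator to CustomStreamWrapper. A standalone sketch of the same pattern, with placeholder url/headers/data rather than LiteLLM's actual call sites:

import httpx

def stream_openai_lines(url: str, headers: dict, data: dict):
    # Open a streamed POST and yield raw SSE lines, mirroring OpenAIChatCompletion.streaming.
    with httpx.Client(timeout=600) as client:
        with client.stream("POST", url, json=data, headers=headers) as response:
            if response.status_code != 200:
                response.read()  # body must be read before .text on a streamed response
                raise RuntimeError(f"OpenAI error {response.status_code}: {response.text}")
            for line in response.iter_lines():
                if line:
                    yield line
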
def embedding(self,
model: str,
@ -349,7 +365,7 @@ class OpenAIChatCompletion(BaseLLM):
class OpenAITextCompletion(BaseLLM):
_client_session: requests.Session
_client_session: httpx.Client
def __init__(self) -> None:
super().__init__()
@ -367,7 +383,7 @@ class OpenAITextCompletion(BaseLLM):
try:
## RESPONSE OBJECT
if response_object is None or model_response_object is None:
raise OpenAIError(status_code=500, message="Error in response object format")
raise ValueError("Error in response object format")
choice_list=[]
for idx, choice in enumerate(response_object["choices"]):
message = Message(content=choice["text"], role="assistant")
@ -386,8 +402,8 @@ class OpenAITextCompletion(BaseLLM):
model_response_object._hidden_params["original_response"] = response_object # track original response, if users make a litellm.text_completion() request, we can return the original response
return model_response_object
except:
OpenAIError(status_code=500, message="Invalid response object.")
except Exception as e:
raise e
def completion(self,
model: Optional[str]=None,
@ -397,6 +413,7 @@ class OpenAITextCompletion(BaseLLM):
api_key: Optional[str]=None,
api_base: Optional[str]=None,
logging_obj=None,
acompletion: bool = False,
optional_params=None,
litellm_params=None,
logger_fn=None,
@ -412,9 +429,6 @@ class OpenAITextCompletion(BaseLLM):
api_base = f"{api_base}/completions"
if len(messages)>0 and "content" in messages[0] and type(messages[0]["content"]) == list:
# Note: internal logic - for enabling litellm.text_completion()
# text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
# https://platform.openai.com/docs/api-reference/completions/create
prompt = messages[0]["content"]
else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore
@ -431,19 +445,13 @@ class OpenAITextCompletion(BaseLLM):
api_key=api_key,
additional_args={"headers": headers, "api_base": api_base, "data": data},
)
if "stream" in optional_params and optional_params["stream"] == True:
response = self._client_session.post(
url=f"{api_base}",
json=data,
headers=headers,
stream=optional_params["stream"]
)
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## RESPONSE OBJECT
return response.iter_lines()
if acompletion == True:
if optional_params.get("stream", False):
return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
else:
return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model)
elif optional_params.get("stream", False):
return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, headers=headers, model_response=model_response, model=model)
else:
response = self._client_session.post(
url=f"{api_base}",
@ -451,7 +459,7 @@ class OpenAITextCompletion(BaseLLM):
headers=headers,
)
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## LOGGING
logging_obj.post_call(
@ -466,12 +474,76 @@ class OpenAITextCompletion(BaseLLM):
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response.json(), model_response_object=model_response)
except OpenAIError as e:
exception_mapping_worked = True
raise e
except Exception as e:
if exception_mapping_worked:
raise e
else:
import traceback
raise OpenAIError(status_code=500, message=traceback.format_exc())
raise e
async def acompletion(self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
prompt: str,
api_key: str,
model: str):
async with httpx.AsyncClient() as client:
response = await client.post(api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
def streaming(self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str
):
with self._client_session.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST"
) as response:
if response.status_code != 200:
response.read() # read the streamed body so .text is available before raising
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
streamwrapper = CustomStreamWrapper(completion_stream=response.iter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
for transformed_chunk in streamwrapper:
yield transformed_chunk
async def async_streaming(self,
logging_obj,
api_base: str,
data: dict,
headers: dict,
model_response: ModelResponse,
model: str):
client = httpx.AsyncClient()
async with client.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST"
) as response:
if response.status_code != 200:
await response.aread() # read the streamed body so .text is available before raising
raise OpenAIError(status_code=response.status_code, message=response.text, request=response.request, response=response)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
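
The async variant mirrors the sync helper but goes through httpx.AsyncClient and aiter_lines(). A sketch with placeholder arguments (RuntimeError stands in for the module's OpenAIError):

import httpx

async def astream_openai_lines(url: str, headers: dict, data: dict):
    # Async counterpart of the streaming helpers above (OpenAITextCompletion.async_streaming).
    async with httpx.AsyncClient(timeout=600) as client:
        async with client.stream("POST", url, json=data, headers=headers) as response:
            if response.status_code != 200:
                await response.aread()  # load the body before .text on a streamed response
                raise RuntimeError(f"OpenAI error {response.status_code}: {response.text}")
            async for line in response.aiter_lines():
                if line:
                    yield line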

@ -138,8 +138,10 @@ async def acompletion(*args, **kwargs):
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
if (custom_llm_provider == "openai" or custom_llm_provider == "azure" or custom_llm_provider == "custom_openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
if (custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "text-completion-openai"): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False):
response = completion(*args, **kwargs)
else:
@ -161,7 +163,7 @@ async def acompletion(*args, **kwargs):
line
async for line in response
)
else:
else:
end_time = datetime.datetime.now()
# [OPTIONAL] ADD TO CACHE
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
@ -596,15 +598,16 @@ def completion(
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn
)
if "stream" in optional_params and optional_params["stream"] == True:
response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
return response
# if "stream" in optional_params and optional_params["stream"] == True:
# response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
# return response
response = model_response
elif (
"replicate" in model or

@ -28,15 +28,13 @@ def test_async_response():
user_message = "Hello, how are you?"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="gpt-3.5-turbo", messages=messages)
print(f"response: {response}")
response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages)
print(f"response: {response}")
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
response = asyncio.run(test_get_response())
# print(response)
print(response)
# test_async_response()
def test_get_response_streaming():
@ -45,8 +43,7 @@ def test_get_response_streaming():
user_message = "write a short poem in one sentence"
messages = [{"content": user_message, "role": "user"}]
try:
response = await acompletion(model="azure/chatgpt-v-2", messages=messages, stream=True)
# response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
response = await acompletion(model="gpt-3.5-turbo-instruct", messages=messages, stream=True)
print(type(response))
import inspect
@ -59,18 +56,17 @@ def test_get_response_streaming():
async for chunk in response:
token = chunk["choices"][0]["delta"].get("content", "")
output += token
print(f"output: {output}")
assert output is not None, "output cannot be None."
assert isinstance(output, str), "output needs to be of type str"
assert len(output) > 0, f"Length of output needs to be greater than 0. {output}"
assert len(output) > 0, "Length of output needs to be greater than 0."
print(f'output: {output}')
except Exception as e:
pytest.fail(f"An exception occurred: {e}")
return response
asyncio.run(test_async_call())
test_get_response_streaming()
# test_get_response_streaming()
def test_get_response_non_openai_streaming():
import asyncio

@ -9,7 +9,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from openai.error import Timeout
from openai import Timeout
import litellm
from litellm import embedding, completion, completion_cost
from litellm import RateLimitError
@ -405,7 +405,6 @@ def test_completion_openai():
litellm.api_key = os.environ['OPENAI_API_KEY']
response = completion(model="gpt-3.5-turbo", messages=messages, max_tokens=10, request_timeout=10)
print("This is the response object\n", response)
print("\n\nThis is response ms:", response.response_ms)
response_str = response["choices"][0]["message"]["content"]
@ -422,14 +421,15 @@ def test_completion_openai():
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_openai()
test_completion_openai()
def test_completion_text_openai():
try:
litellm.set_verbose = True
# litellm.set_verbose = True
response = completion(model="gpt-3.5-turbo-instruct", messages=messages)
print(response)
print(response["choices"][0]["message"]["content"])
except Exception as e:
print(e)
pytest.fail(f"Error occurred: {e}")
# test_completion_text_openai()

@ -14,7 +14,7 @@ import litellm
from litellm import completion_with_retries, completion
from litellm import (
AuthenticationError,
InvalidRequestError,
BadRequestError,
RateLimitError,
ServiceUnavailableError,
OpenAIError,

@ -1,4 +1,8 @@
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
try:
from openai import AuthenticationError, BadRequestError, RateLimitError, OpenAIError
except:
from openai.error import AuthenticationError, InvalidRequestError, RateLimitError, OpenAIError
import os
import sys
import traceback
@ -38,23 +42,24 @@ models = ["command-nightly"]
# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models)
def test_context_window(model):
sample_text = "Say error 50 times" * 100000
sample_text = "Say error 50 times" * 10000
messages = [{"content": sample_text, "role": "user"}]
print(f"model: {model}")
try:
completion(model=model, messages=messages)
response = completion(model=model, messages=messages)
print(f"response: {response}")
print("FAILED!")
pytest.fail(f"An exception occurred")
except ContextWindowExceededError:
pass
except ContextWindowExceededError as e:
print(f"Worked!")
except RateLimitError:
pass
print("RateLimited!")
except Exception as e:
print(f"{e}")
pytest.fail(f"An error occcurred - {e}")
@pytest.mark.parametrize("model", models)
def test_context_window_with_fallbacks(model):
ctx_window_fallback_dict = {"command-nightly": "claude-2"}
ctx_window_fallback_dict = {"command-nightly": "claude-2", "gpt-3.5-turbo-instruct": "gpt-3.5-turbo-16k"}
sample_text = "how does a court case get to the Supreme Court?" * 1000
messages = [{"content": sample_text, "role": "user"}]
@ -62,8 +67,8 @@ def test_context_window_with_fallbacks(model):
# for model in litellm.models_by_provider["bedrock"]:
# test_context_window(model=model)
# test_context_window(model="gpt-3.5-turbo-instruct")
# test_context_window_with_fallbacks(model="command-nightly")
# test_context_window(model="gpt-3.5-turbo")
# test_context_window_with_fallbacks(model="gpt-3.5-turbo")
# Test 2: InvalidAuth Errors
@pytest.mark.parametrize("model", models)
def invalid_auth(model): # set the model key to an invalid key, depending on the model
@ -158,14 +163,14 @@ def invalid_auth(model): # set the model key to an invalid key, depending on th
# for model in litellm.models_by_provider["bedrock"]:
# invalid_auth(model=model)
# invalid_auth(model="gpt-3.5-turbo-instruct")
# invalid_auth(model="gpt-3.5-turbo")
# Test 3: Invalid Request Error
@pytest.mark.parametrize("model", models)
def test_invalid_request_error(model):
messages = [{"content": "hey, how's it going?", "role": "user"}]
with pytest.raises(InvalidRequestError):
with pytest.raises(BadRequestError):
completion(model=model, messages=messages, max_tokens="hello world")
# test_invalid_request_error(model="gpt-3.5-turbo")
@ -178,15 +183,16 @@ def test_invalid_request_error(model):
# response = completion(model=model, messages=messages)
# except RateLimitError:
# return True
# except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
# return True
# # except OpenAIError: # is at least an openai error -> in case of random model errors - e.g. overloaded server
# # return True
# except Exception as e:
# print(f"Uncaught Exception {model}: {type(e).__name__} - {e}")
# traceback.print_exc()
# pass
# return False
# # Repeat each model 500 times
# extended_models = [model for model in models for _ in range(250)]
# # extended_models = [model for model in models for _ in range(250)]
# extended_models = ["gpt-3.5-turbo-instruct" for _ in range(250)]
# def worker(model):
# return test_model_call(model)

@ -11,7 +11,7 @@ sys.path.insert(
from dotenv import load_dotenv
load_dotenv()
import litellm
from litellm import completion, acompletion, AuthenticationError, InvalidRequestError, RateLimitError
from litellm import completion, acompletion, AuthenticationError, BadRequestError, RateLimitError, ModelResponse
litellm.logging = False
litellm.set_verbose = False
@ -47,38 +47,17 @@ first_openai_chunk_example = {
def validate_first_format(chunk):
# write a test to make sure chunk follows the same format as first_openai_chunk_example
assert isinstance(chunk, dict), "Chunk should be a dictionary."
assert "id" in chunk, "Chunk should have an 'id'."
assert isinstance(chunk, ModelResponse), "Chunk should be a ModelResponse."
assert isinstance(chunk['id'], str), "'id' should be a string."
assert "object" in chunk, "Chunk should have an 'object'."
assert isinstance(chunk['object'], str), "'object' should be a string."
assert "created" in chunk, "Chunk should have a 'created'."
assert isinstance(chunk['created'], int), "'created' should be an integer."
assert "model" in chunk, "Chunk should have a 'model'."
assert isinstance(chunk['model'], str), "'model' should be a string."
assert "choices" in chunk, "Chunk should have 'choices'."
assert isinstance(chunk['choices'], list), "'choices' should be a list."
for choice in chunk['choices']:
assert isinstance(choice, dict), "Each choice should be a dictionary."
assert "index" in choice, "Each choice should have 'index'."
assert isinstance(choice['index'], int), "'index' should be an integer."
assert "delta" in choice, "Each choice should have 'delta'."
assert isinstance(choice['delta'], dict), "'delta' should be a dictionary."
assert "role" in choice['delta'], "'delta' should have a 'role'."
assert isinstance(choice['delta']['role'], str), "'role' should be a string."
assert "content" in choice['delta'], "'delta' should have 'content'."
assert isinstance(choice['delta']['content'], str), "'content' should be a string."
assert "finish_reason" in choice, "Each choice should have 'finish_reason'."
assert (choice['finish_reason'] is None) or isinstance(choice['finish_reason'], str), "'finish_reason' should be None or a string."
second_openai_chunk_example = {
@ -98,35 +77,16 @@ second_openai_chunk_example = {
}
def validate_second_format(chunk):
assert isinstance(chunk, dict), "Chunk should be a dictionary."
assert "id" in chunk, "Chunk should have an 'id'."
assert isinstance(chunk, ModelResponse), "Chunk should be a ModelResponse."
assert isinstance(chunk['id'], str), "'id' should be a string."
assert "object" in chunk, "Chunk should have an 'object'."
assert isinstance(chunk['object'], str), "'object' should be a string."
assert "created" in chunk, "Chunk should have a 'created'."
assert isinstance(chunk['created'], int), "'created' should be an integer."
assert "model" in chunk, "Chunk should have a 'model'."
assert isinstance(chunk['model'], str), "'model' should be a string."
assert "choices" in chunk, "Chunk should have 'choices'."
assert isinstance(chunk['choices'], list), "'choices' should be a list."
for choice in chunk['choices']:
assert isinstance(choice, dict), "Each choice should be a dictionary."
assert "index" in choice, "Each choice should have 'index'."
assert isinstance(choice['index'], int), "'index' should be an integer."
assert "delta" in choice, "Each choice should have 'delta'."
assert isinstance(choice['delta'], dict), "'delta' should be a dictionary."
assert "content" in choice['delta'], "'delta' should have 'content'."
assert isinstance(choice['delta']['content'], str), "'content' should be a string."
assert "finish_reason" in choice, "Each choice should have 'finish_reason'."
assert (choice['finish_reason'] is None) or isinstance(choice['finish_reason'], str), "'finish_reason' should be None or a string."
last_openai_chunk_example = {
@ -144,32 +104,15 @@ last_openai_chunk_example = {
}
def validate_last_format(chunk):
assert isinstance(chunk, dict), "Chunk should be a dictionary."
assert "id" in chunk, "Chunk should have an 'id'."
assert isinstance(chunk, ModelResponse), "Chunk should be a ModelResponse."
assert isinstance(chunk['id'], str), "'id' should be a string."
assert "object" in chunk, "Chunk should have an 'object'."
assert isinstance(chunk['object'], str), "'object' should be a string."
assert "created" in chunk, "Chunk should have a 'created'."
assert isinstance(chunk['created'], int), "'created' should be an integer."
assert "model" in chunk, "Chunk should have a 'model'."
assert isinstance(chunk['model'], str), "'model' should be a string."
assert "choices" in chunk, "Chunk should have 'choices'."
assert isinstance(chunk['choices'], list), "'choices' should be a list."
for choice in chunk['choices']:
assert isinstance(choice, dict), "Each choice should be a dictionary."
assert "index" in choice, "Each choice should have 'index'."
assert isinstance(choice['index'], int), "'index' should be an integer."
assert "delta" in choice, "Each choice should have 'delta'."
assert isinstance(choice['delta'], dict), "'delta' should be a dictionary."
assert "finish_reason" in choice, "Each choice should have 'finish_reason'."
assert isinstance(choice['finish_reason'], str), "'finish_reason' should be a string."
def streaming_format_tests(idx, chunk):
@ -188,6 +131,7 @@ def streaming_format_tests(idx, chunk):
if chunk["choices"][0]["finish_reason"]: # ensure finish reason is only in last chunk
validate_last_format(chunk=chunk)
finished = True
print(f"chunk choices: {chunk['choices'][0]['delta']['content']}")
if "content" in chunk["choices"][0]["delta"]:
extracted_chunk = chunk["choices"][0]["delta"]["content"]
print(f"extracted chunk: {extracted_chunk}")
@ -549,7 +493,7 @@ def test_completion_claude_stream_bad_key():
pytest.fail(f"Error occurred: {e}")
test_completion_claude_stream_bad_key()
# test_completion_claude_stream_bad_key()
# test_completion_replicate_stream()
# def test_completion_vertexai_stream():
@ -824,7 +768,7 @@ def ai21_completion_call_bad_key():
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except InvalidRequestError as e:
except BadRequestError as e:
pass
except:
pytest.fail(f"error occurred: {traceback.format_exc()}")
@ -885,7 +829,7 @@ def ai21_completion_call_bad_key():
# test on openai completion call
def test_openai_chat_completion_call():
try:
litellm.set_verbose = True
litellm.set_verbose = False
response = completion(
model="gpt-3.5-turbo", messages=messages, stream=True
)
@ -904,7 +848,7 @@ def test_openai_chat_completion_call():
print(f"error occurred: {traceback.format_exc()}")
pass
# test_openai_chat_completion_call()
test_openai_chat_completion_call()
def test_openai_chat_completion_complete_response_call():
try:
@ -928,6 +872,7 @@ def test_openai_text_completion_call():
start_time = time.time()
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
print(f"chunk: {chunk}")
complete_response += chunk
if finished:
break
@ -939,6 +884,8 @@ def test_openai_text_completion_call():
print(f"error occurred: {traceback.format_exc()}")
pass
# test_openai_text_completion_call()
# # test on together ai completion call - starcoder
def test_together_ai_completion_call_starcoder():
try:
@ -992,7 +939,7 @@ def test_together_ai_completion_call_starcoder_bad_key():
if complete_response == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
except InvalidRequestError as e:
except BadRequestError as e:
pass
except:
print(f"error occurred: {traceback.format_exc()}")

@ -17,7 +17,10 @@ from concurrent import futures
from inspect import iscoroutinefunction
from functools import wraps
from threading import Thread
from openai.error import Timeout
try:
from openai import Timeout
except:
from openai.error import Timeout
def timeout(timeout_duration: float = 0.0, exception_to_raise=Timeout):

File diff suppressed because it is too large.