forked from phoenix/litellm-mirror
fix(acompletion): support client side timeouts + raise exceptions correctly for async calls
parent c4e53aa77b
commit 0ab6b2451d
8 changed files with 142 additions and 81 deletions
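
For orientation before the per-file hunks: a usage sketch (not part of the diff) of the behavior this commit targets. The `timeout` kwarg and the `litellm.Timeout` exception are the ones exercised by the tests changed below; the model name and message are placeholders.

```python
# Usage sketch (not part of the commit): per-request, client-side timeouts.
# Assumes OPENAI_API_KEY is set; the model name is a placeholder.
import asyncio
import litellm
from litellm import acompletion

async def main():
    try:
        response = await acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            timeout=5,  # seconds; forwarded to the underlying OpenAI/Azure SDK client
        )
        print(f"response: {response}")
    except litellm.Timeout:
        # slow calls now raise litellm.Timeout instead of hanging or leaking raw httpx errors
        print("request timed out")

asyncio.run(main())
```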
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Any
 import types, requests
 from .base import BaseLLM
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, convert_to_model_response_object
@@ -98,6 +98,7 @@ class AzureChatCompletion(BaseLLM):
 api_type: str,
 azure_ad_token: str,
 print_verbose: Callable,
+timeout,
 logging_obj,
 optional_params,
 litellm_params,
@@ -129,13 +130,13 @@ class AzureChatCompletion(BaseLLM):
 )
 if acompletion is True:
 if optional_params.get("stream", False):
-return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
+return self.async_streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
 else:
-return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token)
+return self.acompletion(api_base=api_base, data=data, model_response=model_response, api_key=api_key, api_version=api_version, model=model, azure_ad_token=azure_ad_token, timeout=timeout)
 elif "stream" in optional_params and optional_params["stream"] == True:
-return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token)
+return self.streaming(logging_obj=logging_obj, api_base=api_base, data=data, model=model, api_key=api_key, api_version=api_version, azure_ad_token=azure_ad_token, timeout=timeout)
 else:
-azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
+azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
 response = azure_client.chat.completions.create(**data) # type: ignore
 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
 except AzureOpenAIError as e:
@@ -150,16 +151,18 @@ class AzureChatCompletion(BaseLLM):
 model: str,
 api_base: str,
 data: dict,
+timeout: Any,
 model_response: ModelResponse,
 azure_ad_token: Optional[str]=None, ):
+response = None
 try:
-azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
+azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
 response = await azure_client.chat.completions.create(**data)
 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
 except Exception as e:
 if isinstance(e,httpx.TimeoutException):
 raise AzureOpenAIError(status_code=500, message="Request Timeout Error")
-elif response and hasattr(response, "text"):
+elif response is not None and hasattr(response, "text"):
 raise AzureOpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
 else:
 raise AzureOpenAIError(status_code=500, message=f"{str(e)}")
@@ -171,9 +174,10 @@ class AzureChatCompletion(BaseLLM):
 api_version: str,
 data: dict,
 model: str,
+timeout: Any,
 azure_ad_token: Optional[str]=None,
 ):
-azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session)
+azure_client = AzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.client_session, timeout=timeout)
 response = azure_client.chat.completions.create(**data)
 streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
 for transformed_chunk in streamwrapper:
@@ -186,8 +190,9 @@ class AzureChatCompletion(BaseLLM):
 api_version: str,
 data: dict,
 model: str,
+timeout: Any,
 azure_ad_token: Optional[str]=None):
-azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session)
+azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=api_base, azure_deployment=model, azure_ad_token=azure_ad_token, http_client=litellm.aclient_session, timeout=timeout)
 response = await azure_client.chat.completions.create(**data)
 streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="azure",logging_obj=logging_obj)
 async for transformed_chunk in streamwrapper:
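
The AzureChatCompletion hunks above thread the caller's timeout into the `AzureOpenAI` / `AsyncAzureOpenAI` clients and turn `httpx.TimeoutException` into an error whose message ("Request Timeout Error") the exception mapper later converts to `litellm.Timeout` (see the utils hunks at the end of this diff). A standalone sketch of that pattern, assuming placeholder credentials and endpoint, not the committed code:

```python
# Sketch of the error-mapping pattern used above (not the committed code).
# Endpoint, deployment, key and api_version are placeholders.
import httpx
from openai import AsyncAzureOpenAI

async def azure_chat(data: dict, timeout: float):
    response = None
    try:
        client = AsyncAzureOpenAI(
            api_key="placeholder-key",
            api_version="2023-07-01-preview",
            azure_endpoint="https://example.openai.azure.com",
            azure_deployment="placeholder-deployment",
            timeout=timeout,  # forwarded to the SDK's httpx layer
        )
        response = await client.chat.completions.create(**data)
        return response
    except Exception as e:
        if isinstance(e, httpx.TimeoutException):
            # a recognizable message lets a generic mapper raise a Timeout-type error
            raise RuntimeError("Request Timeout Error") from e
        if response is not None and hasattr(response, "text"):
            raise RuntimeError(f"{e}\n\nOriginal Response: {response.text}") from e
        raise
```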
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Any
 import types, time, json
 import httpx
 from .base import BaseLLM
@@ -161,6 +161,7 @@ class OpenAIChatCompletion(BaseLLM):

 def completion(self,
 model_response: ModelResponse,
+timeout: Any,
 model: Optional[str]=None,
 messages: Optional[list]=None,
 print_verbose: Optional[Callable]=None,
@@ -180,6 +181,9 @@ class OpenAIChatCompletion(BaseLLM):
 if model is None or messages is None:
 raise OpenAIError(status_code=422, message=f"Missing model or messages")

+if not isinstance(timeout, float):
+raise OpenAIError(status_code=422, message=f"Timeout needs to be a float")

 for _ in range(2): # if call fails due to alternating messages, retry with reformatted message
 data = {
 "model": model,
@@ -197,13 +201,13 @@ class OpenAIChatCompletion(BaseLLM):
 try:
 if acompletion is True:
 if optional_params.get("stream", False):
-return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
+return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
 else:
-return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key)
+return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout)
 elif optional_params.get("stream", False):
-return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key)
+return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout)
 else:
-openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
+openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
 response = openai_client.chat.completions.create(**data) # type: ignore
 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
 except Exception as e:
@@ -234,27 +238,32 @@ class OpenAIChatCompletion(BaseLLM):
 async def acompletion(self,
 data: dict,
 model_response: ModelResponse,
+timeout: float,
 api_key: Optional[str]=None,
 api_base: Optional[str]=None):
 response = None
 try:
-openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
+openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
 response = await openai_aclient.chat.completions.create(**data)
 return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
 except Exception as e:
 if response and hasattr(response, "text"):
 raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
 else:
+if type(e).__name__ == "ReadTimeout":
+raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
+else:
 raise OpenAIError(status_code=500, message=f"{str(e)}")

 def streaming(self,
 logging_obj,
+timeout: float,
 data: dict,
 model: str,
 api_key: Optional[str]=None,
 api_base: Optional[str]=None
 ):
-openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session)
+openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout)
 response = openai_client.chat.completions.create(**data)
 streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
 for transformed_chunk in streamwrapper:
@@ -262,15 +271,26 @@ class OpenAIChatCompletion(BaseLLM):

 async def async_streaming(self,
 logging_obj,
+timeout: float,
 data: dict,
 model: str,
 api_key: Optional[str]=None,
 api_base: Optional[str]=None):
-openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session)
+response = None
+try:
+openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout)
 response = await openai_aclient.chat.completions.create(**data)
 streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
 async for transformed_chunk in streamwrapper:
 yield transformed_chunk
+except Exception as e: # need to exception handle here. async exceptions don't get caught in sync functions.
+if response is not None and hasattr(response, "text"):
+raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
+else:
+if type(e).__name__ == "ReadTimeout":
+raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
+else:
+raise OpenAIError(status_code=500, message=f"{str(e)}")

 def embedding(self,
 model: str,
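
The comment added in the last hunk, "async exceptions don't get caught in sync functions", is the reason async_streaming now wraps its body in try/except: an async generator's body only runs when the consumer iterates it, so errors surface during iteration, not at call time. A minimal standard-library illustration:

```python
import asyncio

async def stream_chunks():
    yield "chunk-1"
    # the failure happens while the consumer is iterating,
    # not when the generator object is created
    raise ConnectionError("read timed out")

async def main():
    try:
        gen = stream_chunks()  # no exception here
    except ConnectionError:
        print("never reached")
    try:
        async for chunk in gen:
            print(chunk)
    except ConnectionError as e:
        # this is where the error actually surfaces, which is why the
        # diff adds try/except inside async_streaming itself
        print(f"caught during iteration: {e}")

asyncio.run(main())
```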
@@ -10,7 +10,6 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
@@ -18,7 +17,6 @@ import litellm
 from litellm import ( # type: ignore
 client,
 exception_type,
-timeout,
 get_optional_params,
 get_litellm_params,
 Logging,
@@ -176,28 +174,28 @@ async def acompletion(*args, **kwargs):
 init_response = completion(*args, **kwargs)
 if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
 response = init_response
-else:
+elif asyncio.iscoroutine(init_response):
 response = await init_response
 else:
 # Call the synchronous function using run_in_executor
 response = await loop.run_in_executor(None, func_with_context)
 if kwargs.get("stream", False): # return an async generator
-# do not change this
-# for stream = True, always return an async generator
-# See OpenAI acreate https://github.com/openai/openai-python/blob/5d50e9e3b39540af782ca24e65c290343d86e1a9/openai/api_resources/abstract/engine_api_resource.py#L193
-# return response
-return(
-line
-async for line in response
-)
+return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
 else:
 return response
 except Exception as e:
 ## Map to OpenAI Exception
 raise exception_type(
 model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
 )

+async def _async_streaming(response, model, custom_llm_provider, args):
+try:
+async for line in response:
+yield line
+except Exception as e:
+raise exception_type(
+model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+)

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
 """
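
The rewritten acompletion distinguishes three cases: an already-materialized (cached) response, a coroutine from a provider with a native async path, and everything else, which still runs the synchronous call in a thread executor; streaming errors are funneled through the new _async_streaming helper so they can be mapped like any other exception. A reduced sketch of that dispatch; the names are illustrative, not litellm's API:

```python
# Illustrative dispatch sketch (names are not litellm's API).
import asyncio
from functools import partial

def fake_completion(**kwargs):
    """Stand-in for completion(): may return a cached dict, a coroutine, or block."""
    mode = kwargs.get("mode", "blocking")
    if mode == "cached":
        return {"via": "cache"}                      # already-materialized response
    if mode == "async":
        async def _call():
            return {"via": "native async provider"}
        return _call()                               # a coroutine, not yet awaited
    import time
    time.sleep(0.05)                                 # simulate a blocking provider call
    return "response via sync path"

async def dispatch(**kwargs):
    loop = asyncio.get_running_loop()
    func_with_context = partial(fake_completion, **kwargs)
    init_response = fake_completion(**kwargs)
    if isinstance(init_response, dict):              # caching-style scenario
        return init_response
    elif asyncio.iscoroutine(init_response):         # provider exposes an async path
        return await init_response
    else:                                            # run the sync call off the event loop
        return await loop.run_in_executor(None, func_with_context)

for mode in ("cached", "async", "blocking"):
    print(mode, "->", asyncio.run(dispatch(mode=mode)))
```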
@@ -245,6 +243,7 @@ def completion(
 messages: List = [],
 functions: List = [],
 function_call: str = "", # optional params
+timeout: Union[float, int] = 600.0,
 temperature: Optional[float] = None,
 top_p: Optional[float] = None,
 n: Optional[int] = None,
@@ -261,7 +260,6 @@ def completion(
 tools: Optional[List] = None,
 tool_choice: Optional[str] = None,
 deployment_id = None,

 # set api_base, api_version, api_key
 base_url: Optional[str] = None,
 api_version: Optional[str] = None,
@@ -298,9 +296,8 @@ def completion(

 LITELLM Specific Params
 mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None).
-force_timeout (int, optional): The maximum execution time in seconds for the completion request (default is 600).
 custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock"
-num_retries (int, optional): The number of retries to attempt (default is 0).
+max_retries (int, optional): The number of retries to attempt (default is 0).
 Returns:
 ModelResponse: A response object containing the generated completion and associated metadata.

@@ -314,7 +311,7 @@ def completion(
 api_base = kwargs.get('api_base', None)
 return_async = kwargs.get('return_async', False)
 mock_response = kwargs.get('mock_response', None)
-force_timeout= kwargs.get('force_timeout', 600)
+force_timeout= kwargs.get('force_timeout', 600) ## deprecated
 logger_fn = kwargs.get('logger_fn', None)
 verbose = kwargs.get('verbose', False)
 custom_llm_provider = kwargs.get('custom_llm_provider', None)
@@ -338,8 +335,10 @@ def completion(
 litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "max_retries"]
 default_params = openai_params + litellm_params
 non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider

 if mock_response:
 return mock_completion(model, messages, stream=stream, mock_response=mock_response)
+timeout = float(timeout)
 try:
 if base_url:
 api_base = base_url
@@ -486,7 +485,8 @@ def completion(
 litellm_params=litellm_params,
 logger_fn=logger_fn,
 logging_obj=logging,
-acompletion=acompletion
+acompletion=acompletion,
+timeout=timeout
 )

 ## LOGGING
@@ -552,7 +552,8 @@ def completion(
 logging_obj=logging,
 optional_params=optional_params,
 litellm_params=litellm_params,
-logger_fn=logger_fn
+logger_fn=logger_fn,
+timeout=timeout
 )
 except Exception as e:
 ## LOGGING - log the original exception returned
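
completion() now takes `timeout` directly (default 600.0) and coerces it with `timeout = float(timeout)` before handing it to the provider classes, which is why an integer like `timeout=5` passes the `isinstance(timeout, float)` check added in OpenAIChatCompletion. A short usage sketch (not from the diff; the model name is a placeholder):

```python
# Sync usage sketch: an int timeout is accepted and coerced to float().
from litellm import completion

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    timeout=5,  # becomes 5.0 via float(timeout) before reaching the provider class
)
print(f"response: {response}")
```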
@@ -1399,6 +1400,24 @@ def completion_with_retries(*args, **kwargs):
 retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
 return retryer(original_function, *args, **kwargs)

+async def acompletion_with_retries(*args, **kwargs):
+"""
+Executes a litellm.completion() with 3 retries
+"""
+try:
+import tenacity
+except Exception as e:
+raise Exception(f"tenacity import failed please run `pip install tenacity`. Error{e}")
+
+num_retries = kwargs.pop("num_retries", 3)
+retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
+original_function = kwargs.pop("original_function", completion)
+if retry_strategy == "constant_retry":
+retryer = tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+elif retry_strategy == "exponential_backoff_retry":
+retryer = tenacity.Retrying(wait=tenacity.wait_exponential(multiplier=1, max=10), stop=tenacity.stop_after_attempt(num_retries), reraise=True)
+return await retryer(original_function, *args, **kwargs)


 def batch_completion(
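
The new acompletion_with_retries mirrors completion_with_retries: it pops `num_retries` (default 3) and `retry_strategy` ("constant_retry" or "exponential_backoff_retry") and builds the matching tenacity retryer. A standalone sketch of the same tenacity pattern, illustrative rather than litellm's internals:

```python
# Standalone sketch of the retry strategies selected above (not litellm's internals).
import tenacity

def build_retryer(num_retries: int = 3, retry_strategy: str = "constant_retry") -> tenacity.Retrying:
    if retry_strategy == "exponential_backoff_retry":
        return tenacity.Retrying(
            wait=tenacity.wait_exponential(multiplier=1, max=10),
            stop=tenacity.stop_after_attempt(num_retries),
            reraise=True,
        )
    return tenacity.Retrying(stop=tenacity.stop_after_attempt(num_retries), reraise=True)

attempts = {"n": 0}

def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient failure")
    return "ok"

# third attempt succeeds; reraise=True would surface the last error if all attempts failed
print(build_retryer(num_retries=3, retry_strategy="exponential_backoff_retry")(flaky_call))
```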
@@ -1639,9 +1658,6 @@ async def aembedding(*args, **kwargs):
 return response

 @client
-@timeout( # type: ignore
-60
-) ## set timeouts, in case calls hang (e.g. Azure) - default is 60s, override with `force_timeout`
 def embedding(
 model,
 input=[],

@@ -26,4 +26,3 @@ model_list:
 litellm_settings:
 drop_params: True
 success_callback: ["langfuse"] # https://docs.litellm.ai/docs/observability/langfuse_integration

@@ -33,7 +33,7 @@ class Router:
 redis_port: Optional[int] = None,
 redis_password: Optional[str] = None,
 cache_responses: bool = False,
-num_retries: Optional[int] = None,
+num_retries: int = 0,
 timeout: float = 600,
 default_litellm_params = {}, # default params for Router.chat.completion.create
 routing_strategy: Literal["simple-shuffle", "least-busy"] = "simple-shuffle") -> None:
@@ -42,12 +42,13 @@ class Router:
 self.set_model_list(model_list)
 self.healthy_deployments: List = self.model_list

 if num_retries:
 self.num_retries = num_retries

 self.chat = litellm.Chat(params=default_litellm_params)
-
-litellm.request_timeout = timeout
+self.default_litellm_params = {
+"timeout": timeout
+}
 self.routing_strategy = routing_strategy
 ### HEALTH CHECK THREAD ###
 if self.routing_strategy == "least-busy":
@@ -222,6 +223,9 @@ class Router:
 # pick the one that is available (lowest TPM/RPM)
 deployment = self.get_available_deployment(model=model, messages=messages)
 data = deployment["litellm_params"]
+for k, v in self.default_litellm_params.items():
+if k not in data: # prioritize model-specific params > default router params
+data[k] = v
 return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})


@@ -234,16 +238,20 @@ class Router:
 try:
 deployment = self.get_available_deployment(model=model, messages=messages)
 data = deployment["litellm_params"]
+for k, v in self.default_litellm_params.items():
+if k not in data: # prioritize model-specific params > default router params
+data[k] = v
 response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, **kwargs})
+if inspect.iscoroutinefunction(response): # async errors are often returned as coroutines
+response = await response
 return response
 except Exception as e:
 if self.num_retries > 0:
 kwargs["model"] = model
 kwargs["messages"] = messages
 kwargs["original_exception"] = e
 kwargs["original_function"] = self.acompletion
 return await self.async_function_with_retries(**kwargs)
 else:
 raise e

 def text_completion(self,
 model: str,
@@ -258,6 +266,9 @@ class Router:
 deployment = self.get_available_deployment(model=model, messages=messages)

 data = deployment["litellm_params"]
+for k, v in self.default_litellm_params.items():
+if k not in data: # prioritize model-specific params > default router params
+data[k] = v
 # call via litellm.completion()
 return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore

@@ -270,6 +281,9 @@ class Router:
 deployment = self.get_available_deployment(model=model, input=input)

 data = deployment["litellm_params"]
+for k, v in self.default_litellm_params.items():
+if k not in data: # prioritize model-specific params > default router params
+data[k] = v
 # call via litellm.embedding()
 return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})

@@ -282,4 +296,7 @@ class Router:
 deployment = self.get_available_deployment(model=model, input=input)

 data = deployment["litellm_params"]
+for k, v in self.default_litellm_params.items():
+if k not in data: # prioritize model-specific params > default router params
+data[k] = v
 return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, **kwargs})
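
Rather than mutating the global `litellm.request_timeout`, the Router now keeps the timeout in `self.default_litellm_params` and merges it into each deployment's `litellm_params` unless the deployment sets its own value. A usage sketch with placeholder deployment values:

```python
# Router usage sketch (not from the diff); keys and model names are placeholders.
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": "placeholder-key",
            # no per-deployment timeout here, so the router default below applies
        },
    }
]

router = Router(model_list=model_list, num_retries=2, timeout=10)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the weather like in SF?"}],
)
print(f"response: {response}")
```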
@@ -5,8 +5,6 @@ import sys, os
 import pytest
 import traceback
 import asyncio, logging
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)

 sys.path.insert(
 0, os.path.abspath("../..")
@@ -20,7 +18,8 @@ def test_sync_response():
 user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 try:
-response = completion(model="gpt-3.5-turbo", messages=messages, api_key=os.environ["OPENAI_API_KEY"])
+response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5)
+print(f"response: {response}")
 except Exception as e:
 pytest.fail(f"An exception occurred: {e}")
 # test_sync_response()
@@ -30,7 +29,7 @@ def test_sync_response_anyscale():
 user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 try:
-response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
+response = completion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
 except Exception as e:
 pytest.fail(f"An exception occurred: {e}")

@@ -43,10 +42,13 @@ def test_async_response_openai():
 user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 try:
-response = await acompletion(model="gpt-3.5-turbo", messages=messages)
+response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5)
+print(f"response: {response}")
+except litellm.Timeout as e:
+pass
 except Exception as e:
 pytest.fail(f"An exception occurred: {e}")
 print(e)

 asyncio.run(test_get_response())

@@ -56,17 +58,18 @@ def test_async_response_azure():
 import asyncio
 litellm.set_verbose = True
 async def test_get_response():
-user_message = "Hello, how are you?"
+user_message = "What do you know?"
 messages = [{"content": user_message, "role": "user"}]
 try:
-response = await acompletion(model="azure/chatgpt-v-2", messages=messages)
+response = await acompletion(model="azure/chatgpt-v-2", messages=messages, timeout=5)
 print(f"response: {response}")
+except litellm.Timeout as e:
+pass
 except Exception as e:
 pytest.fail(f"An exception occurred: {e}")

 asyncio.run(test_get_response())


 def test_async_anyscale_response():
 import asyncio
 litellm.set_verbose = True
@@ -74,9 +77,11 @@ def test_async_anyscale_response():
 user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 try:
-response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages)
+response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, timeout=5)
 # response = await response
 print(f"response: {response}")
+except litellm.Timeout as e:
+pass
 except Exception as e:
 pytest.fail(f"An exception occurred: {e}")

@@ -91,7 +96,7 @@ def test_get_response_streaming():
 messages = [{"content": user_message, "role": "user"}]
 try:
 litellm.set_verbose = True
-response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True)
+response = await acompletion(model="gpt-3.5-turbo", messages=messages, stream=True, timeout=5)
 print(type(response))

 import inspect
@@ -108,12 +113,12 @@ def test_get_response_streaming():
 assert isinstance(output, str), "output needs to be of type str"
 assert len(output) > 0, "Length of output needs to be greater than 0."
 print(f'output: {output}')
+except litellm.Timeout as e:
+pass
 except Exception as e:
 pytest.fail(f"An exception occurred: {e}")
 return response
 asyncio.run(test_async_call())


 # test_get_response_streaming()

 def test_get_response_non_openai_streaming():
@@ -123,7 +128,7 @@ def test_get_response_non_openai_streaming():
 user_message = "Hello, how are you?"
 messages = [{"content": user_message, "role": "user"}]
 try:
-response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True)
+response = await acompletion(model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", messages=messages, stream=True, timeout=5)
 print(type(response))

 import inspect
@@ -140,10 +145,11 @@ def test_get_response_non_openai_streaming():
 assert output is not None, "output cannot be None."
 assert isinstance(output, str), "output needs to be of type str"
 assert len(output) > 0, "Length of output needs to be greater than 0."

+except litellm.Timeout as e:
+pass
 except Exception as e:
 pytest.fail(f"An exception occurred: {e}")
 return response
 asyncio.run(test_async_call())

 # test_get_response_non_openai_streaming()
+test_get_response_non_openai_streaming()

@@ -242,11 +242,11 @@ def test_acompletion_on_router():
 ]

 messages = [
-{"role": "user", "content": "What is the weather like in Boston?"}
+{"role": "user", "content": "What is the weather like in SF?"}
 ]

 async def get_response():
-router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=10)
+router = Router(model_list=model_list, redis_host=os.environ["REDIS_HOST"], redis_password=os.environ["REDIS_PASSWORD"], redis_port=os.environ["REDIS_PORT"], cache_responses=True, timeout=0.1)
 response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)
 print(f"response1: {response1}")
 response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages)

@@ -18,7 +18,7 @@ import tiktoken
 import uuid
 import aiohttp
 import logging
-import asyncio, httpx
+import asyncio, httpx, inspect
 import copy
 from tokenizers import Tokenizer
 from dataclasses import (
@@ -1047,7 +1047,6 @@ def exception_logging(
 # make it easy to log if completion/embedding runs succeeded or failed + see what happened | Non-Blocking
 def client(original_function):
 global liteDebuggerClient, get_all_keys
-import inspect
 def function_setup(
 start_time, *args, **kwargs
 ): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -1333,13 +1332,13 @@ def client(original_function):
 kwargs["retry_strategy"] = "exponential_backoff_retry"
 elif (isinstance(e, openai.APIError)): # generic api error
 kwargs["retry_strategy"] = "constant_retry"
-return litellm.completion_with_retries(*args, **kwargs)
+return await litellm.acompletion_with_retries(*args, **kwargs)
 elif isinstance(e, litellm.exceptions.ContextWindowExceededError) and context_window_fallback_dict and model in context_window_fallback_dict:
 if len(args) > 0:
 args[0] = context_window_fallback_dict[model]
 else:
 kwargs["model"] = context_window_fallback_dict[model]
-return original_function(*args, **kwargs)
+return await original_function(*args, **kwargs)
 traceback_exception = traceback.format_exc()
 crash_reporting(*args, **kwargs, exception=traceback_exception)
 end_time = datetime.datetime.now()
@@ -3370,7 +3369,7 @@ def exception_type(
 else:
 exception_type = ""

-if "Request Timeout Error" in error_str:
+if "Request Timeout Error" in error_str or "Request timed out" in error_str:
 exception_mapping_worked = True
 raise Timeout(
 message=f"APITimeoutError - Request timed out",
@@ -3411,7 +3410,6 @@ def exception_type(
 message=f"OpenAIException - {original_exception.message}",
 model=model,
 llm_provider="openai",
-request=original_exception.request
 )
 if original_exception.status_code == 422:
 exception_mapping_worked = True
@@ -4229,7 +4227,7 @@ def exception_type(
 llm_provider=custom_llm_provider,
 response=original_exception.response
 )
-else: # ensure generic errors always return APIConnectionError
+else: # ensure generic errors always return APIConnectionError=
 exception_mapping_worked = True
 if hasattr(original_exception, "request"):
 raise APIConnectionError(
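
The @client wrapper hunk above also shows the async retry/fallback path now awaiting `original_function` and `acompletion_with_retries`. For the context-window fallback branch it walks through, here is a hedged usage sketch of `context_window_fallback_dict` (the kwarg appears in the litellm_params list earlier in this diff; model names are placeholders and the exact fallback behavior is assumed from the wrapper code):

```python
# Hedged sketch: fall back to a larger-context model when the prompt overflows.
# Model names are placeholders; behavior is inferred from the wrapper hunk above.
import litellm
from litellm import completion

long_prompt = "word " * 10000  # likely to exceed a 4k-token context window

try:
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": long_prompt}],
        context_window_fallback_dict={"gpt-3.5-turbo": "gpt-3.5-turbo-16k"},
    )
    print(f"response: {response}")
except litellm.exceptions.ContextWindowExceededError:
    # raised only if the fallback model cannot fit the prompt either
    print("prompt too large even for the fallback model")
```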