Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
fix(acompletion): support fallbacks on acompletion (#7184)
* fix(acompletion): support fallbacks on acompletion

  allows health checks for wildcard routes to use fallback models

* test: update cohere generate api testing

* add max tokens to health check (#7000)

* fix: fix health check test

* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
This commit is contained in:
parent 5fe77499d2
commit a9aeb21d0b

8 changed files with 240 additions and 69 deletions
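For orientation, a minimal usage sketch of the behavior this commit adds. The model names below are illustrative assumptions, not taken from the diff, and valid provider credentials are assumed to be configured:

import asyncio
import litellm

async def main():
    # `fallbacks` on the async path is what this commit wires up: if the primary
    # model call fails, async_completion_with_fallbacks retries each listed
    # fallback (60s cool-down per failing model, ~45s total time budget).
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # illustrative primary model
        messages=[{"role": "user", "content": "What's 1 + 1?"}],
        fallbacks=["gpt-4o-mini", "claude-3-haiku-20240307"],  # illustrative fallbacks
    )
    print(response)

if __name__ == "__main__":
    asyncio.run(main())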
@@ -534,7 +534,7 @@ def add_known_models():
            gemini_models.append(key)
        elif value.get("litellm_provider") == "fireworks_ai":
            # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
-            if "-to-" not in key:
+            if "-to-" not in key and "fireworks-ai-default" not in key:
                fireworks_ai_models.append(key)
        elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
            # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
@@ -189,7 +189,11 @@ def exception_type(  # type: ignore  # noqa: PLR0915
        #################### Start of Provider Exception mapping ####################
        ################################################################################

-        if "Request Timeout Error" in error_str or "Request timed out" in error_str:
+        if (
+            "Request Timeout Error" in error_str
+            or "Request timed out" in error_str
+            or "Timed out generating response" in error_str
+        ):
            exception_mapping_worked = True
            raise Timeout(
                message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}",
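Since the broadened match above means provider errors whose message contains "Timed out generating response" now surface as litellm.Timeout as well, a caller-side sketch of catching it (model and timeout value are illustrative, credentials assumed configured):

import litellm

try:
    litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model
        messages=[{"role": "user", "content": "hi"}],
        timeout=0.01,  # deliberately tiny, likely to trigger the timeout path
    )
except litellm.Timeout as e:
    # the same exception type the audio tests further down catch and skip on
    print(f"request timed out: {e}")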
@@ -23,6 +23,7 @@ from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
+    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,

@@ -54,10 +55,15 @@ class BaseLLMHTTPHandler:
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
    ):
+        if client is None:
            async_httpx_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders(custom_llm_provider)
            )
+        else:
+            async_httpx_client = client
+
        try:
            response = await async_httpx_client.post(
                url=api_base,

@@ -97,6 +103,7 @@ class BaseLLMHTTPHandler:
        fake_stream: bool = False,
        api_key: Optional[str] = None,
        headers={},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ):
        provider_config = ProviderConfigManager.get_provider_chat_config(
            model=model, provider=litellm.LlmProviders(custom_llm_provider)

@@ -149,6 +156,11 @@ class BaseLLMHTTPHandler:
                logging_obj=logging_obj,
                data=data,
                fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
            )

        else:

@@ -167,6 +179,11 @@ class BaseLLMHTTPHandler:
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
            )

        if stream is True:

@@ -182,6 +199,11 @@ class BaseLLMHTTPHandler:
                logging_obj=logging_obj,
                timeout=timeout,
                fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, HTTPHandler)
+                    else None
+                ),
            )
            return CustomStreamWrapper(
                completion_stream=completion_stream,

@@ -190,11 +212,14 @@ class BaseLLMHTTPHandler:
            logging_obj=logging_obj,
        )

+        if client is None or not isinstance(client, HTTPHandler):
            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client

        try:
            response = sync_httpx_client.post(
-                api_base,
+                url=api_base,
                headers=headers,
                data=json.dumps(data),
                timeout=timeout,

@@ -229,8 +254,12 @@ class BaseLLMHTTPHandler:
        logging_obj,
        timeout: Optional[Union[float, httpx.Timeout]],
        fake_stream: bool = False,
+        client: Optional[HTTPHandler] = None,
    ) -> Tuple[Any, httpx.Headers]:
+        if client is None or not isinstance(client, HTTPHandler):
            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
        try:
            stream = True
            if fake_stream is True:

@@ -289,6 +318,7 @@ class BaseLLMHTTPHandler:
        logging_obj: LiteLLMLoggingObj,
        data: dict,
        fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
    ):
        completion_stream, _response_headers = await self.make_async_call(
            custom_llm_provider=custom_llm_provider,

@@ -300,6 +330,7 @@ class BaseLLMHTTPHandler:
            logging_obj=logging_obj,
            timeout=timeout,
            fake_stream=fake_stream,
+            client=client,
        )
        streamwrapper = CustomStreamWrapper(
            completion_stream=completion_stream,

@@ -320,10 +351,14 @@ class BaseLLMHTTPHandler:
        logging_obj: LiteLLMLoggingObj,
        timeout: Optional[Union[float, httpx.Timeout]],
        fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
    ) -> Tuple[Any, httpx.Headers]:
+        if client is None:
            async_httpx_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders(custom_llm_provider)
            )
+        else:
+            async_httpx_client = client
        stream = True
        if fake_stream is True:
            stream = False
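The `client` parameter threaded through the handler above is what the updated Cohere test later in this diff relies on. A rough standalone sketch mirroring that test (mock-based, so no real request is sent):

from unittest.mock import patch

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# A caller-supplied HTTPHandler is passed through completion() to the shared
# HTTP handler, which lets the outgoing POST be intercepted and asserted on.
client = HTTPHandler()
with patch.object(client, "post") as mock_post:
    try:
        litellm.completion(
            model="cohere/command",
            messages=[{"role": "user", "content": "Hey"}],
            max_tokens=10,
            client=client,  # request is sent through this handler
        )
    except Exception:
        pass  # the mocked post returns no usable response
    mock_post.assert_called_once()
    print(mock_post.call_args.kwargs["url"])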
@@ -64,6 +64,7 @@ from litellm.secret_managers.main import get_secret_str
from litellm.utils import (
    CustomStreamWrapper,
    Usage,
+    async_completion_with_fallbacks,
    async_mock_completion_streaming_obj,
    completion_with_fallbacks,
    convert_to_model_response_object,

@@ -364,6 +365,8 @@ async def acompletion(
    - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
    - If `stream` is True, the function returns an async generator that yields completion lines.
    """
+    fallbacks = kwargs.get("fallbacks", None)
+
    loop = asyncio.get_event_loop()
    custom_llm_provider = kwargs.get("custom_llm_provider", None)
    # Adjusted to use explicit arguments instead of *args and **kwargs

@@ -407,6 +410,18 @@ async def acompletion(
        _, custom_llm_provider, _, _ = get_llm_provider(
            model=model, api_base=completion_kwargs.get("base_url", None)
        )
+
+    fallbacks = fallbacks or litellm.model_fallbacks
+    if fallbacks is not None:
+        response = await async_completion_with_fallbacks(
+            **completion_kwargs, kwargs={"fallbacks": fallbacks}
+        )
+        if response is None:
+            raise Exception(
+                "No response from fallbacks. Got none. Turn on `litellm.set_verbose=True` to see more details."
+            )
+        return response
+
    try:
        # Use a partial function to pass your keyword arguments
        func = partial(completion, **completion_kwargs, **kwargs)

@@ -1884,6 +1899,7 @@ def completion( # type: ignore # noqa: PLR0915
            encoding=encoding,
            api_key=cohere_key,
            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+            client=client,
        )
    elif custom_llm_provider == "cohere_chat":
        cohere_key = (

@@ -4997,6 +5013,38 @@ def speech(
##### Health Endpoints #######################


+async def ahealth_check_chat_models(
+    model: str, custom_llm_provider: str, model_params: dict
+) -> dict:
+    if "*" in model:
+        from litellm.litellm_core_utils.llm_request_utils import (
+            pick_cheapest_chat_model_from_llm_provider,
+        )
+
+        # this is a wildcard model, we need to pick a random model from the provider
+        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+            custom_llm_provider=custom_llm_provider
+        )
+        fallback_models: Optional[List] = None
+        if custom_llm_provider in litellm.models_by_provider:
+            models = litellm.models_by_provider[custom_llm_provider]
+            random.shuffle(models)  # Shuffle the models list in place
+            fallback_models = models[
+                :2
+            ]  # Pick the first 2 models from the shuffled list
+        model_params["model"] = cheapest_model
+        model_params["fallbacks"] = fallback_models
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response: dict = {}  # args like remaining ratelimit etc.
+    else:  # default to completion calls
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response = {}  # args like remaining ratelimit etc.
+
+    return response
+
+
async def ahealth_check(  # noqa: PLR0915
    model_params: dict,
    mode: Optional[

@@ -5128,21 +5176,12 @@ async def ahealth_check(  # noqa: PLR0915
            model_params["documents"] = ["my sample text"]
            await litellm.arerank(**model_params)
            response = {}
-        elif "*" in model:
-            from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_chat_model_from_llm_provider,
-            )
-
-            # this is a wildcard model, we need to pick a random model from the provider
-            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-                custom_llm_provider=custom_llm_provider
-            )
-            model_params["model"] = cheapest_model
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
-        else:  # default to completion calls
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
+        else:
+            response = await ahealth_check_chat_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
        return response
    except Exception as e:
        stack_trace = traceback.format_exc()
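The new ahealth_check_chat_models path above is exercised end to end by the updated health-check test at the bottom of this diff. A standalone sketch of the same call (GROQ_API_KEY must be set; a wildcard model resolves to the provider's cheapest chat model, with up to two shuffled fallbacks and max_tokens=1):

import asyncio
import os

import litellm

async def main():
    response = await litellm.ahealth_check(
        model_params={
            "api_key": os.environ.get("GROQ_API_KEY"),
            "model": "groq/*",  # wildcard route, resolved by ahealth_check_chat_models
            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
        },
        mode=None,
    )
    print(response)

if __name__ == "__main__":
    asyncio.run(main())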
@@ -5697,6 +5697,72 @@ def completion_with_fallbacks(**kwargs):
    return response


+async def async_completion_with_fallbacks(**kwargs):
+    nested_kwargs = kwargs.pop("kwargs", {})
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    original_model = kwargs["model"]
+    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
+    if "fallbacks" in nested_kwargs:
+        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
+    if "acompletion" in kwargs:
+        del kwargs[
+            "acompletion"
+        ]  # remove acompletion so it doesn't lead to keyword errors
+    litellm_call_id = str(uuid.uuid4())
+
+    # max time to process a request with fallbacks: default 45s
+    while response is None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                # check if it's dict or new model string
+                if isinstance(
+                    model, dict
+                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
+                    kwargs["api_key"] = model.get("api_key", None)
+                    kwargs["api_base"] = model.get("api_base", None)
+                    model = model.get("model", original_model)
+                elif (
+                    model in rate_limited_models
+                ):  # check if model is currently cooling down
+                    if (
+                        model_expiration_times.get(model)
+                        and time.time() >= model_expiration_times[model]
+                    ):
+                        rate_limited_models.remove(
+                            model
+                        )  # check if it's been 60s of cool down and remove model
+                    else:
+                        continue  # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get("model"):
+                    del kwargs["model"]
+
+                print_verbose(f"trying to make completion call with model: {model}")
+                kwargs["litellm_call_id"] = litellm_call_id
+                kwargs = {
+                    **kwargs,
+                    **nested_kwargs,
+                }  # combine the openai + litellm params at the same level
+                response = await litellm.acompletion(**kwargs, model=model)
+                print_verbose(f"response: {response}")
+                if response is not None:
+                    return response
+
+            except Exception as e:
+                print_verbose(f"error: {e}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = (
+                    time.time() + 60
+                )  # cool down this selected model
+                pass
+    return response
+
+
def process_system_message(system_message, max_tokens, model):
    system_message_event = {"role": "system", "content": system_message}
    system_message_tokens = get_token_count([system_message_event], model)
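Per the inline comment in async_completion_with_fallbacks above, fallbacks can also be dicts carrying their own credentials and endpoint. A hedged sketch of that shape (all key and endpoint values are placeholders):

import asyncio
import litellm

async def main():
    # Dict-style fallbacks override api_key/api_base for that attempt, and a
    # "model" key, if present, replaces the original model for that attempt.
    return await litellm.acompletion(
        model="gpt-4",  # illustrative primary model
        messages=[{"role": "user", "content": "Hello"}],
        fallbacks=[
            {
                "model": "gpt-3.5-turbo",  # illustrative
                "api_key": "sk-placeholder",
                "api_base": "https://example.invalid/v1",
            }
        ],
    )

if __name__ == "__main__":
    asyncio.run(main())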
@@ -20,10 +20,13 @@ from litellm import completion
from litellm.llms.cohere.completion.transformation import CohereTextConfig


-@pytest.mark.asyncio
-async def test_cohere_generate_api_completion():
+def test_cohere_generate_api_completion():
    try:
-        litellm.set_verbose = False
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+        from unittest.mock import patch, MagicMock
+
+        client = HTTPHandler()
+        litellm.set_verbose = True
        messages = [
            {"role": "system", "content": "You're a good bot"},
            {

@@ -31,12 +34,28 @@ async def test_cohere_generate_api_completion():
                "content": "Hey",
            },
        ]
-        response = completion(
-            model="cohere/command-nightly",
-            messages=messages,
-            max_tokens=10,
-        )
-        print(response)
+
+        with patch.object(client, "post") as mock_client:
+            try:
+                completion(
+                    model="cohere/command",
+                    messages=messages,
+                    max_tokens=10,
+                    client=client,
+                )
+            except Exception as e:
+                print(e)
+            mock_client.assert_called_once()
+            print("mock_client.call_args.kwargs", mock_client.call_args.kwargs)
+
+            assert (
+                mock_client.call_args.kwargs["url"]
+                == "https://api.cohere.ai/v1/generate"
+            )
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert json_data["model"] == "command"
+            assert json_data["prompt"] == "You're a good bot Hey"
+            assert json_data["max_tokens"] == 10
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

@@ -53,7 +72,7 @@ async def test_cohere_generate_api_stream():
            },
        ]
        response = await litellm.acompletion(
-            model="cohere/command-nightly",
+            model="cohere/command",
            messages=messages,
            max_tokens=10,
            stream=True,

@@ -76,7 +95,7 @@ def test_completion_cohere_stream_bad_key():
            },
        ]
        completion(
-            model="command-nightly",
+            model="command",
            messages=messages,
            stream=True,
            max_tokens=50,

@@ -100,7 +119,7 @@ def test_cohere_transform_request():
    headers = {}

    transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
        messages=messages,
        optional_params=optional_params,
        litellm_params={},

@@ -109,7 +128,7 @@ def test_cohere_transform_request():

    print("transformed_request", json.dumps(transformed_request, indent=4))

-    assert transformed_request["model"] == "command-nightly"
+    assert transformed_request["model"] == "command"
    assert transformed_request["prompt"] == "You're a helpful bot Hello"
    assert transformed_request["max_tokens"] == 10
    assert transformed_request["temperature"] == 0.7

@@ -137,7 +156,7 @@ def test_cohere_transform_request_with_tools():
    optional_params = {"tools": tools}

    transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
        messages=messages,
        optional_params=optional_params,
        litellm_params={},

@@ -168,7 +187,7 @@ def test_cohere_map_openai_params():
    mapped_params = config.map_openai_params(
        non_default_params=openai_params,
        optional_params={},
-        model="command-nightly",
+        model="command",
        drop_params=False,
    )
@@ -56,6 +56,7 @@ async def test_audio_output_from_model(stream):
    if stream is False:
        audio_format = "wav"
    litellm.set_verbose = False
    try:
        completion = await litellm.acompletion(
            model="gpt-4o-audio-preview",
            modalities=["text", "audio"],

@@ -63,6 +64,9 @@ async def test_audio_output_from_model(stream):
            messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
            stream=stream,
        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

    if stream is True:
        await check_streaming_response(completion)

@@ -88,7 +92,7 @@ async def test_audio_input_to_model(stream):
    response.raise_for_status()
    wav_data = response.content
    encoded_string = base64.b64encode(wav_data).decode("utf-8")

    try:
        completion = await litellm.acompletion(
            model="gpt-4o-audio-preview",
            modalities=["text", "audio"],

@@ -107,6 +111,9 @@ async def test_audio_input_to_model(stream):
                },
            ],
        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

    if stream is True:
        await check_streaming_response(completion)
@@ -112,16 +112,17 @@ async def test_sagemaker_embedding_health_check():


@pytest.mark.asyncio
-async def test_fireworks_health_check():
+async def test_groq_health_check():
    """
    This should not fail

    ensure that provider wildcard model passes health check
    """
    litellm.set_verbose = True
    response = await litellm.ahealth_check(
        model_params={
-            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
-            "model": "fireworks_ai/*",
+            "api_key": os.environ.get("GROQ_API_KEY"),
+            "model": "groq/*",
            "messages": [{"role": "user", "content": "What's 1 + 1?"}],
        },
        mode=None,