fix(acompletion): support fallbacks on acompletion (#7184)

* fix(acompletion): support fallbacks on acompletion

allows health checks for wildcard routes to use fallback models

* test: update cohere generate api testing

* add max tokens to health check (#7000)

* fix: fix health check test

* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
Author: Krish Dholakia
Date:   2024-12-11 19:20:54 -08:00, committed by GitHub
Parent: 5fe77499d2
Commit: a9aeb21d0b

8 changed files with 240 additions and 69 deletions
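Usage sketch (model names here are placeholders, not from this commit): with this change, async callers can pass `fallbacks` straight to `acompletion`:

    import asyncio

    import litellm


    async def main():
        # `fallbacks` is read from kwargs before dispatch (see the acompletion
        # hunk below); on failure of the primary model, the listed models are tried
        response = await litellm.acompletion(
            model="gpt-4o",  # placeholder primary model
            messages=[{"role": "user", "content": "ping"}],
            fallbacks=["gpt-4o-mini"],  # placeholder fallback model
        )
        print(response)


    asyncio.run(main())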


@@ -64,6 +64,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
+    async_completion_with_fallbacks,
     async_mock_completion_streaming_obj,
     completion_with_fallbacks,
     convert_to_model_response_object,
@@ -364,6 +365,8 @@ async def acompletion(
     - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
+    fallbacks = kwargs.get("fallbacks", None)
+
     loop = asyncio.get_event_loop()
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs
@@ -407,6 +410,18 @@ async def acompletion(
         _, custom_llm_provider, _, _ = get_llm_provider(
             model=model, api_base=completion_kwargs.get("base_url", None)
         )
+
+    fallbacks = fallbacks or litellm.model_fallbacks
+    if fallbacks is not None:
+        response = await async_completion_with_fallbacks(
+            **completion_kwargs, kwargs={"fallbacks": fallbacks}
+        )
+        if response is None:
+            raise Exception(
+                "No response from fallbacks. Got none. Turn on `litellm.set_verbose=True` to see more details."
+            )
+        return response
+
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)
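A second sketch of the default path added above: when no per-call `fallbacks` kwarg is given, `litellm.model_fallbacks` is consulted instead (`fallbacks or litellm.model_fallbacks`). That it takes the same list-of-model-names shape as the kwarg is an assumption here:

    import asyncio

    import litellm

    # module-level default, used only when no `fallbacks` kwarg is passed;
    # assumed to take the same list shape as the per-call kwarg
    litellm.model_fallbacks = ["gpt-4o-mini"]  # placeholder model name


    async def main():
        # no fallbacks kwarg here, so the module-level default applies
        response = await litellm.acompletion(
            model="gpt-4o",  # placeholder primary model
            messages=[{"role": "user", "content": "ping"}],
        )
        print(response)


    asyncio.run(main())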
@@ -1884,6 +1899,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                 encoding=encoding,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class, as we may need to modify I/O to fit the provider's requirements
+                client=client,
             )
         elif custom_llm_provider == "cohere_chat":
             cohere_key = (
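The `client=client` line above forwards a caller-supplied client into the cohere generate handler. A minimal sketch, assuming the handler accepts litellm's own `HTTPHandler` (as the updated cohere generate tests in this commit suggest):

    import litellm
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    # reuse one handler across calls (or swap in a mock in tests); that the
    # cohere path accepts this handler type is an assumption based on this diff
    client = HTTPHandler()

    response = litellm.completion(
        model="command-nightly",  # placeholder cohere generate-API model
        messages=[{"role": "user", "content": "ping"}],
        client=client,
    )
    print(response)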
@@ -4997,6 +5013,38 @@ def speech(
 ##### Health Endpoints #######################


+async def ahealth_check_chat_models(
+    model: str, custom_llm_provider: str, model_params: dict
+) -> dict:
+    if "*" in model:
+        from litellm.litellm_core_utils.llm_request_utils import (
+            pick_cheapest_chat_model_from_llm_provider,
+        )
+
+        # this is a wildcard model: pick the provider's cheapest chat model
+        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+            custom_llm_provider=custom_llm_provider
+        )
+        fallback_models: Optional[List] = None
+        if custom_llm_provider in litellm.models_by_provider:
+            models = litellm.models_by_provider[custom_llm_provider]
+            random.shuffle(models)  # shuffle the models list in place
+            fallback_models = models[
+                :2
+            ]  # pick the first 2 models from the shuffled list
+        model_params["model"] = cheapest_model
+        model_params["fallbacks"] = fallback_models
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response: dict = {}  # args like remaining ratelimit etc.
+    else:  # default to completion calls
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response = {}  # args like remaining ratelimit etc.
+    return response
+
+
 async def ahealth_check(  # noqa: PLR0915
     model_params: dict,
     mode: Optional[
@@ -5128,21 +5176,12 @@ async def ahealth_check(  # noqa: PLR0915
             model_params["documents"] = ["my sample text"]
             await litellm.arerank(**model_params)
             response = {}
-        elif "*" in model:
-            from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_chat_model_from_llm_provider,
-            )
-
-            # this is a wildcard model, we need to pick a random model from the provider
-            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-                custom_llm_provider=custom_llm_provider
-            )
-            model_params["model"] = cheapest_model
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
-        else:  # default to completion calls
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
+        else:
+            response = await ahealth_check_chat_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
         return response
     except Exception as e:
         stack_trace = traceback.format_exc()
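Taken together, the two hunks above route wildcard health checks through `ahealth_check_chat_models`. A usage sketch (the wildcard route and mode value are placeholders, not from this commit):

    import asyncio

    import litellm


    async def main():
        # for a wildcard route, ahealth_check_chat_models swaps in the provider's
        # cheapest chat model, attaches up to 2 shuffled fallbacks, and caps the
        # probe at max_tokens=1
        result = await litellm.ahealth_check(
            model_params={
                "model": "openai/*",  # placeholder wildcard route
                "messages": [{"role": "user", "content": "ping"}],
            },
            mode="chat",  # assumed mode value for chat-model health checks
        )
        print(result)


    asyncio.run(main())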