Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
fix(acompletion): support fallbacks on acompletion (#7184)
* fix(acompletion): support fallbacks on acompletion; allows health checks for wildcard routes to use fallback models

* test: update cohere generate api testing

* add max tokens to health check (#7000)

* fix: fix health check test

* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
This commit is contained in:
parent 5fe77499d2
commit a9aeb21d0b
8 changed files with 240 additions and 69 deletions
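In effect, this change lets callers pass `fallbacks` straight to `litellm.acompletion`, which is what the proxy's health checks rely on for wildcard routes. A minimal usage sketch, assuming provider credentials are configured; the model names below are illustrative placeholders, not taken from the commit:

```python
import asyncio

import litellm

async def main() -> None:
    # If the primary model errors out, litellm retries with each fallback in order.
    response = await litellm.acompletion(
        model="gpt-4o",  # illustrative primary model
        messages=[{"role": "user", "content": "ping"}],
        fallbacks=["gpt-4o-mini", "claude-3-haiku-20240307"],  # illustrative fallbacks
        max_tokens=1,  # keep the probe cheap, as the health-check code does
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```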
@@ -64,6 +64,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
+    async_completion_with_fallbacks,
     async_mock_completion_streaming_obj,
     completion_with_fallbacks,
     convert_to_model_response_object,
@@ -364,6 +365,8 @@ async def acompletion(
     - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
+    fallbacks = kwargs.get("fallbacks", None)
+
     loop = asyncio.get_event_loop()
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs
@@ -407,6 +410,18 @@ async def acompletion(
     _, custom_llm_provider, _, _ = get_llm_provider(
         model=model, api_base=completion_kwargs.get("base_url", None)
     )
+
+    fallbacks = fallbacks or litellm.model_fallbacks
+    if fallbacks is not None:
+        response = await async_completion_with_fallbacks(
+            **completion_kwargs, kwargs={"fallbacks": fallbacks}
+        )
+        if response is None:
+            raise Exception(
+                "No response from fallbacks. Got none. Turn on `litellm.set_verbose=True` to see more details."
+            )
+        return response
+
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)
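`async_completion_with_fallbacks` itself lives in `litellm.utils` and is not part of this diff. As a rough mental model only, it behaves like the sketch below: try the primary model, then each fallback, and hand back `None` when everything fails (which is why the caller above raises on a `None` response). The real implementation differs in details such as logging and per-model kwargs:

```python
from typing import Any, Optional

import litellm

async def fallback_loop_sketch(**completion_kwargs: Any) -> Optional[Any]:
    """Illustrative stand-in for litellm.utils.async_completion_with_fallbacks."""
    nested = completion_kwargs.pop("kwargs", {})
    candidates = [completion_kwargs.get("model")] + (nested.get("fallbacks") or [])
    for candidate in candidates:
        if candidate is None:
            continue
        try:
            return await litellm.acompletion(**{**completion_kwargs, "model": candidate})
        except Exception:
            continue  # move on to the next candidate model
    return None  # caller is expected to treat None as total failure
```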
@@ -1884,6 +1899,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                 encoding=encoding,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+                client=client,
             )
         elif custom_llm_provider == "cohere_chat":
             cohere_key = (
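The one functional line in this hunk is `client=client,`: a caller-supplied HTTP client is now forwarded to the Cohere handler instead of being dropped. A hedged sketch of why that matters, reusing one connection pool across repeated probes; the `HTTPHandler` import path is an assumption about litellm's internals and the model name is a placeholder:

```python
import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler  # assumed import path

client = HTTPHandler()  # one shared client/connection pool instead of one per call
for _ in range(3):
    litellm.completion(
        model="cohere/command-r",  # placeholder Cohere model
        messages=[{"role": "user", "content": "ping"}],
        client=client,  # forwarded to the provider handler, per this hunk
        max_tokens=1,
    )
```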
@@ -4997,6 +5013,38 @@ def speech(
 ##### Health Endpoints #######################
 
 
+async def ahealth_check_chat_models(
+    model: str, custom_llm_provider: str, model_params: dict
+) -> dict:
+    if "*" in model:
+        from litellm.litellm_core_utils.llm_request_utils import (
+            pick_cheapest_chat_model_from_llm_provider,
+        )
+
+        # this is a wildcard model, we need to pick a random model from the provider
+        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+            custom_llm_provider=custom_llm_provider
+        )
+        fallback_models: Optional[List] = None
+        if custom_llm_provider in litellm.models_by_provider:
+            models = litellm.models_by_provider[custom_llm_provider]
+            random.shuffle(models)  # Shuffle the models list in place
+            fallback_models = models[
+                :2
+            ]  # Pick the first 2 models from the shuffled list
+        model_params["model"] = cheapest_model
+        model_params["fallbacks"] = fallback_models
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response: dict = {}  # args like remaining ratelimit etc.
+    else:  # default to completion calls
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response = {}  # args like remaining ratelimit etc.
+
+    return response
+
+
 async def ahealth_check(  # noqa: PLR0915
     model_params: dict,
     mode: Optional[
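The wildcard branch above turns a route like `openai/*` into a concrete probe: the provider's cheapest chat model becomes the primary and two randomly chosen models from the same provider become fallbacks. A self-contained illustration of just that selection step, with an invented stand-in for `litellm.models_by_provider`:

```python
import random
from typing import Dict, List, Optional

# Invented stand-in for litellm.models_by_provider.
models_by_provider: Dict[str, List[str]] = {
    "openai": ["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"],
}

def pick_fallbacks(custom_llm_provider: str) -> Optional[List[str]]:
    """Mirror the diff: shuffle the provider's models, keep the first two."""
    if custom_llm_provider not in models_by_provider:
        return None
    models = models_by_provider[custom_llm_provider]
    random.shuffle(models)  # the diff shuffles the shared list in place
    return models[:2]

print(pick_fallbacks("openai"))  # e.g. ['gpt-3.5-turbo', 'gpt-4o']
```

Note that `random.shuffle` mutates `litellm.models_by_provider[...]` in place in the committed code, so repeated health checks also reorder the provider's shared model list as a side effect.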
@@ -5128,21 +5176,12 @@ async def ahealth_check(  # noqa: PLR0915
             model_params["documents"] = ["my sample text"]
             await litellm.arerank(**model_params)
             response = {}
-        elif "*" in model:
-            from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_chat_model_from_llm_provider,
-            )
-
-            # this is a wildcard model, we need to pick a random model from the provider
-            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-                custom_llm_provider=custom_llm_provider
-            )
-            model_params["model"] = cheapest_model
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
-        else:  # default to completion calls
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
+        else:
+            response = await ahealth_check_chat_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
         return response
     except Exception as e:
         stack_trace = traceback.format_exc()
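After this refactor, `ahealth_check` routes every non-special mode through `ahealth_check_chat_models`, which covers both wildcard and concrete models. A hedged end-to-end sketch of probing a wildcard route; the `mode` value and route shape are assumptions based on the surrounding code, not shown in this diff:

```python
import asyncio

import litellm

async def probe() -> None:
    # model_params mirrors what a proxy health check might pass.
    result = await litellm.ahealth_check(
        model_params={
            "model": "openai/*",  # wildcard route: cheapest model plus two fallbacks
            "messages": [{"role": "user", "content": "test"}],
        },
        mode="chat",  # assumed mode value
    )
    print(result)  # {} on success; the code comments hint rate-limit info may be added

asyncio.run(probe())
```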