fix(acompletion): support fallbacks on acompletion (#7184)

* fix(acompletion): support fallbacks on acompletion

allows health checks for wildcard routes to use fallback models

* test: update cohere generate api testing

* add max tokens to health check (#7000)

* fix: fix health check test

* test: update testing

---------

Co-authored-by: Cameron <561860+wallies@users.noreply.github.com>
Author: Krish Dholakia, 2024-12-11 19:20:54 -08:00 (committed by GitHub)
Parent: 5fe77499d2
Commit: a9aeb21d0b
8 changed files with 240 additions and 69 deletions
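For context, a minimal usage sketch of what this change enables: `fallbacks` passed directly to the async completion call are now honored. The model names below are illustrative and not taken from this commit.

    import asyncio

    import litellm

    async def main():
        # With this change, `fallbacks` is read by acompletion itself; each
        # fallback model is tried if the primary call raises an error.
        return await litellm.acompletion(
            model="groq/llama3-8b-8192",  # illustrative primary model
            messages=[{"role": "user", "content": "What's 1 + 1?"}],
            fallbacks=["groq/gemma-7b-it"],  # illustrative fallback model
            max_tokens=1,  # health checks now cap generation at 1 token
        )

    asyncio.run(main())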

View file

@@ -534,7 +534,7 @@ def add_known_models():
             gemini_models.append(key)
         elif value.get("litellm_provider") == "fireworks_ai":
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
-            if "-to-" not in key:
+            if "-to-" not in key and "fireworks-ai-default" not in key:
                 fireworks_ai_models.append(key)
         elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.

View file

@@ -189,7 +189,11 @@ def exception_type(  # type: ignore  # noqa: PLR0915
         #################### Start of Provider Exception mapping ####################
         ################################################################################
-        if "Request Timeout Error" in error_str or "Request timed out" in error_str:
+        if (
+            "Request Timeout Error" in error_str
+            or "Request timed out" in error_str
+            or "Timed out generating response" in error_str
+        ):
             exception_mapping_worked = True
             raise Timeout(
                 message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}",

View file

@@ -23,6 +23,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
@@ -54,10 +55,15 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         encoding: Any,
         api_key: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
         try:
             response = await async_httpx_client.post(
                 url=api_base,
@@ -97,6 +103,7 @@ class BaseLLMHTTPHandler:
         fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
@@ -149,6 +156,11 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
                 data=data,
                 fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
             )
         else:
@@ -167,6 +179,11 @@ class BaseLLMHTTPHandler:
                 optional_params=optional_params,
                 litellm_params=litellm_params,
                 encoding=encoding,
+                client=(
+                    client
+                    if client is not None and isinstance(client, AsyncHTTPHandler)
+                    else None
+                ),
             )

         if stream is True:
@@ -182,6 +199,11 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
                 timeout=timeout,
                 fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, HTTPHandler)
+                    else None
+                ),
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -190,11 +212,14 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
             )

-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client

         try:
             response = sync_httpx_client.post(
-                api_base,
+                url=api_base,
                 headers=headers,
                 data=json.dumps(data),
                 timeout=timeout,
@@ -229,8 +254,12 @@ class BaseLLMHTTPHandler:
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
         try:
             stream = True
             if fake_stream is True:
@@ -289,6 +318,7 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         data: dict,
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
@@ -300,6 +330,7 @@ class BaseLLMHTTPHandler:
             logging_obj=logging_obj,
             timeout=timeout,
             fake_stream=fake_stream,
+            client=client,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -320,10 +351,14 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
         stream = True
         if fake_stream is True:
             stream = False

View file

@@ -64,6 +64,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
+    async_completion_with_fallbacks,
     async_mock_completion_streaming_obj,
     completion_with_fallbacks,
     convert_to_model_response_object,
@@ -364,6 +365,8 @@ async def acompletion(
     - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
+    fallbacks = kwargs.get("fallbacks", None)
+
     loop = asyncio.get_event_loop()
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs
@@ -407,6 +410,18 @@ async def acompletion(
         _, custom_llm_provider, _, _ = get_llm_provider(
             model=model, api_base=completion_kwargs.get("base_url", None)
         )
+
+    fallbacks = fallbacks or litellm.model_fallbacks
+    if fallbacks is not None:
+        response = await async_completion_with_fallbacks(
+            **completion_kwargs, kwargs={"fallbacks": fallbacks}
+        )
+        if response is None:
+            raise Exception(
+                "No response from fallbacks. Got none. Turn on `litellm.set_verbose=True` to see more details."
+            )
+        return response
+
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)
@@ -1884,6 +1899,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                 encoding=encoding,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+                client=client,
             )
         elif custom_llm_provider == "cohere_chat":
             cohere_key = (
@@ -4997,6 +5013,38 @@ def speech(
 ##### Health Endpoints #######################


+async def ahealth_check_chat_models(
+    model: str, custom_llm_provider: str, model_params: dict
+) -> dict:
+    if "*" in model:
+        from litellm.litellm_core_utils.llm_request_utils import (
+            pick_cheapest_chat_model_from_llm_provider,
+        )
+
+        # this is a wildcard model, we need to pick a random model from the provider
+        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+            custom_llm_provider=custom_llm_provider
+        )
+        fallback_models: Optional[List] = None
+        if custom_llm_provider in litellm.models_by_provider:
+            models = litellm.models_by_provider[custom_llm_provider]
+            random.shuffle(models)  # Shuffle the models list in place
+            fallback_models = models[:2]  # Pick the first 2 models from the shuffled list
+        model_params["model"] = cheapest_model
+        model_params["fallbacks"] = fallback_models
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response: dict = {}  # args like remaining ratelimit etc.
+    else:  # default to completion calls
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response = {}  # args like remaining ratelimit etc.
+    return response
+
+
 async def ahealth_check(  # noqa: PLR0915
     model_params: dict,
     mode: Optional[
@@ -5128,21 +5176,12 @@ async def ahealth_check(  # noqa: PLR0915
             model_params["documents"] = ["my sample text"]
             await litellm.arerank(**model_params)
             response = {}
-        elif "*" in model:
-            from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_chat_model_from_llm_provider,
-            )
-
-            # this is a wildcard model, we need to pick a random model from the provider
-            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-                custom_llm_provider=custom_llm_provider
-            )
-            model_params["model"] = cheapest_model
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
-        else:  # default to completion calls
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
+        else:
+            response = await ahealth_check_chat_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
+            )
         return response
     except Exception as e:
         stack_trace = traceback.format_exc()

View file

@@ -5697,6 +5697,72 @@ def completion_with_fallbacks(**kwargs):
     return response


+async def async_completion_with_fallbacks(**kwargs):
+    nested_kwargs = kwargs.pop("kwargs", {})
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    original_model = kwargs["model"]
+    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
+    if "fallbacks" in nested_kwargs:
+        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
+    if "acompletion" in kwargs:
+        del kwargs["acompletion"]  # remove acompletion so it doesn't lead to keyword errors
+    litellm_call_id = str(uuid.uuid4())
+
+    # max time to process a request with fallbacks: default 45s
+    while response is None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                # check if it's dict or new model string
+                if isinstance(
+                    model, dict
+                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
+                    kwargs["api_key"] = model.get("api_key", None)
+                    kwargs["api_base"] = model.get("api_base", None)
+                    model = model.get("model", original_model)
+                elif (
+                    model in rate_limited_models
+                ):  # check if model is currently cooling down
+                    if (
+                        model_expiration_times.get(model)
+                        and time.time() >= model_expiration_times[model]
+                    ):
+                        rate_limited_models.remove(
+                            model
+                        )  # check if it's been 60s of cool down and remove model
+                    else:
+                        continue  # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get("model"):
+                    del kwargs["model"]
+
+                print_verbose(f"trying to make completion call with model: {model}")
+                kwargs["litellm_call_id"] = litellm_call_id
+                kwargs = {
+                    **kwargs,
+                    **nested_kwargs,
+                }  # combine the openai + litellm params at the same level
+                response = await litellm.acompletion(**kwargs, model=model)
+                print_verbose(f"response: {response}")
+                if response is not None:
+                    return response
+
+            except Exception as e:
+                print_verbose(f"error: {e}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = (
+                    time.time() + 60
+                )  # cool down this selected model
+                pass
+
+    return response
+
+
 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
     system_message_tokens = get_token_count([system_message_event], model)

View file

@@ -20,10 +20,13 @@ from litellm import completion
 from litellm.llms.cohere.completion.transformation import CohereTextConfig


-@pytest.mark.asyncio
-async def test_cohere_generate_api_completion():
+def test_cohere_generate_api_completion():
     try:
-        litellm.set_verbose = False
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+        from unittest.mock import patch, MagicMock
+
+        client = HTTPHandler()
+        litellm.set_verbose = True
         messages = [
             {"role": "system", "content": "You're a good bot"},
             {
@@ -31,12 +34,28 @@ def test_cohere_generate_api_completion():
                 "content": "Hey",
             },
         ]
-        response = completion(
-            model="cohere/command-nightly",
-            messages=messages,
-            max_tokens=10,
-        )
-        print(response)
+
+        with patch.object(client, "post") as mock_client:
+            try:
+                completion(
+                    model="cohere/command",
+                    messages=messages,
+                    max_tokens=10,
+                    client=client,
+                )
+            except Exception as e:
+                print(e)
+
+            mock_client.assert_called_once()
+            print("mock_client.call_args.kwargs", mock_client.call_args.kwargs)
+            assert (
+                mock_client.call_args.kwargs["url"]
+                == "https://api.cohere.ai/v1/generate"
+            )
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert json_data["model"] == "command"
+            assert json_data["prompt"] == "You're a good bot Hey"
+            assert json_data["max_tokens"] == 10
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -53,7 +72,7 @@ async def test_cohere_generate_api_stream():
             },
         ]
         response = await litellm.acompletion(
-            model="cohere/command-nightly",
+            model="cohere/command",
             messages=messages,
             max_tokens=10,
             stream=True,
@@ -76,7 +95,7 @@ def test_completion_cohere_stream_bad_key():
             },
         ]
         completion(
-            model="command-nightly",
+            model="command",
             messages=messages,
             stream=True,
             max_tokens=50,
@@ -100,7 +119,7 @@ def test_cohere_transform_request():
     headers = {}

     transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
         messages=messages,
         optional_params=optional_params,
         litellm_params={},
@@ -109,7 +128,7 @@ def test_cohere_transform_request():
     print("transformed_request", json.dumps(transformed_request, indent=4))

-    assert transformed_request["model"] == "command-nightly"
+    assert transformed_request["model"] == "command"
     assert transformed_request["prompt"] == "You're a helpful bot Hello"
     assert transformed_request["max_tokens"] == 10
     assert transformed_request["temperature"] == 0.7
@@ -137,7 +156,7 @@ def test_cohere_transform_request_with_tools():
     optional_params = {"tools": tools}

     transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
         messages=messages,
         optional_params=optional_params,
         litellm_params={},
@@ -168,7 +187,7 @@ def test_cohere_map_openai_params():
     mapped_params = config.map_openai_params(
         non_default_params=openai_params,
         optional_params={},
-        model="command-nightly",
+        model="command",
         drop_params=False,
     )

View file

@@ -56,13 +56,17 @@ async def test_audio_output_from_model(stream):
     if stream is False:
         audio_format = "wav"
     litellm.set_verbose = False
-    completion = await litellm.acompletion(
-        model="gpt-4o-audio-preview",
-        modalities=["text", "audio"],
-        audio={"voice": "alloy", "format": "pcm16"},
-        messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
-        stream=stream,
-    )
+    try:
+        completion = await litellm.acompletion(
+            model="gpt-4o-audio-preview",
+            modalities=["text", "audio"],
+            audio={"voice": "alloy", "format": "pcm16"},
+            messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+            stream=stream,
+        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

     if stream is True:
         await check_streaming_response(completion)
@@ -88,25 +92,28 @@ async def test_audio_input_to_model(stream):
     response.raise_for_status()
     wav_data = response.content
     encoded_string = base64.b64encode(wav_data).decode("utf-8")

-    completion = await litellm.acompletion(
-        model="gpt-4o-audio-preview",
-        modalities=["text", "audio"],
-        audio={"voice": "alloy", "format": audio_format},
-        stream=stream,
-        messages=[
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "What is in this recording?"},
-                    {
-                        "type": "input_audio",
-                        "input_audio": {"data": encoded_string, "format": "wav"},
-                    },
-                ],
-            },
-        ],
-    )
+    try:
+        completion = await litellm.acompletion(
+            model="gpt-4o-audio-preview",
+            modalities=["text", "audio"],
+            audio={"voice": "alloy", "format": audio_format},
+            stream=stream,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is in this recording?"},
+                        {
+                            "type": "input_audio",
+                            "input_audio": {"data": encoded_string, "format": "wav"},
+                        },
+                    ],
+                },
+            ],
+        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

     if stream is True:
         await check_streaming_response(completion)

View file

@@ -112,16 +112,17 @@ async def test_sagemaker_embedding_health_check():


 @pytest.mark.asyncio
-async def test_fireworks_health_check():
+async def test_groq_health_check():
     """
     This should not fail

     ensure that provider wildcard model passes health check
     """
+    litellm.set_verbose = True
     response = await litellm.ahealth_check(
         model_params={
-            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
-            "model": "fireworks_ai/*",
+            "api_key": os.environ.get("GROQ_API_KEY"),
+            "model": "groq/*",
             "messages": [{"role": "user", "content": "What's 1 + 1?"}],
         },
         mode=None,