diff --git a/litellm/__init__.py b/litellm/__init__.py
index 81e386f0f9..6f5afbba83 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -534,7 +534,7 @@ def add_known_models():
             gemini_models.append(key)
         elif value.get("litellm_provider") == "fireworks_ai":
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
-            if "-to-" not in key:
+            if "-to-" not in key and "fireworks-ai-default" not in key:
                 fireworks_ai_models.append(key)
         elif value.get("litellm_provider") == "fireworks_ai-embedding-models":
             # ignore the 'up-to', '-to-' model names -> not real models. just for cost tracking based on model params.
diff --git a/litellm/litellm_core_utils/exception_mapping_utils.py b/litellm/litellm_core_utils/exception_mapping_utils.py
index 7317e29284..c4228f0527 100644
--- a/litellm/litellm_core_utils/exception_mapping_utils.py
+++ b/litellm/litellm_core_utils/exception_mapping_utils.py
@@ -189,7 +189,11 @@ def exception_type(  # type: ignore  # noqa: PLR0915
         #################### Start of Provider Exception mapping ####################
         ################################################################################

-        if "Request Timeout Error" in error_str or "Request timed out" in error_str:
+        if (
+            "Request Timeout Error" in error_str
+            or "Request timed out" in error_str
+            or "Timed out generating response" in error_str
+        ):
             exception_mapping_worked = True
             raise Timeout(
                 message=f"APITimeoutError - Request timed out. \nerror_str: {error_str}",
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index e3114a5221..51e1dae6b0 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -23,6 +23,7 @@ from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
     HTTPHandler,
     _get_httpx_client,
     get_async_httpx_client,
@@ -54,10 +55,15 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         encoding: Any,
         api_key: Optional[str] = None,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
+
         try:
             response = await async_httpx_client.post(
                 url=api_base,
@@ -97,6 +103,7 @@ class BaseLLMHTTPHandler:
         fake_stream: bool = False,
         api_key: Optional[str] = None,
         headers={},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )
@@ -149,6 +156,11 @@ class BaseLLMHTTPHandler:
                     logging_obj=logging_obj,
                     data=data,
                     fake_stream=fake_stream,
+                    client=(
+                        client
+                        if client is not None and isinstance(client, AsyncHTTPHandler)
+                        else None
+                    ),
                 )

             else:
@@ -167,6 +179,11 @@ class BaseLLMHTTPHandler:
                     optional_params=optional_params,
                     litellm_params=litellm_params,
                     encoding=encoding,
+                    client=(
+                        client
+                        if client is not None and isinstance(client, AsyncHTTPHandler)
+                        else None
+                    ),
                 )

         if stream is True:
@@ -182,6 +199,11 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
                 timeout=timeout,
                 fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, HTTPHandler)
+                    else None
+                ),
             )
             return CustomStreamWrapper(
                 completion_stream=completion_stream,
@@ -190,11 +212,14 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
             )

-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client

         try:
             response = sync_httpx_client.post(
-                api_base,
+                url=api_base,
                 headers=headers,
                 data=json.dumps(data),
                 timeout=timeout,
@@ -229,8 +254,12 @@ class BaseLLMHTTPHandler:
         logging_obj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[HTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        sync_httpx_client = _get_httpx_client()
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
         try:
             stream = True
             if fake_stream is True:
@@ -289,6 +318,7 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         data: dict,
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ):
         completion_stream, _response_headers = await self.make_async_call(
             custom_llm_provider=custom_llm_provider,
@@ -300,6 +330,7 @@ class BaseLLMHTTPHandler:
             logging_obj=logging_obj,
             timeout=timeout,
             fake_stream=fake_stream,
+            client=client,
         )
         streamwrapper = CustomStreamWrapper(
             completion_stream=completion_stream,
@@ -320,10 +351,14 @@ class BaseLLMHTTPHandler:
         logging_obj: LiteLLMLoggingObj,
         timeout: Optional[Union[float, httpx.Timeout]],
         fake_stream: bool = False,
+        client: Optional[AsyncHTTPHandler] = None,
     ) -> Tuple[Any, httpx.Headers]:
-        async_httpx_client = get_async_httpx_client(
-            llm_provider=litellm.LlmProviders(custom_llm_provider)
-        )
+        if client is None:
+            async_httpx_client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders(custom_llm_provider)
+            )
+        else:
+            async_httpx_client = client
         stream = True
         if fake_stream is True:
             stream = False
diff --git a/litellm/main.py b/litellm/main.py
index 36508cc9a0..b0d87e41d8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -64,6 +64,7 @@ from litellm.secret_managers.main import get_secret_str
 from litellm.utils import (
     CustomStreamWrapper,
     Usage,
+    async_completion_with_fallbacks,
     async_mock_completion_streaming_obj,
     completion_with_fallbacks,
     convert_to_model_response_object,
@@ -364,6 +365,8 @@ async def acompletion(
     - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop.
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
+    fallbacks = kwargs.get("fallbacks", None)
+
     loop = asyncio.get_event_loop()
     custom_llm_provider = kwargs.get("custom_llm_provider", None)
     # Adjusted to use explicit arguments instead of *args and **kwargs
@@ -407,6 +410,18 @@ async def acompletion(
         _, custom_llm_provider, _, _ = get_llm_provider(
             model=model, api_base=completion_kwargs.get("base_url", None)
         )
+
+    fallbacks = fallbacks or litellm.model_fallbacks
+    if fallbacks is not None:
+        response = await async_completion_with_fallbacks(
+            **completion_kwargs, kwargs={"fallbacks": fallbacks}
+        )
+        if response is None:
+            raise Exception(
+                "No response from fallbacks. Got none. Turn on `litellm.set_verbose=True` to see more details."
+            )
+        return response
+
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)
@@ -1884,6 +1899,7 @@ def completion(  # type: ignore  # noqa: PLR0915
                 encoding=encoding,
                 api_key=cohere_key,
                 logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+                client=client,
             )
         elif custom_llm_provider == "cohere_chat":
             cohere_key = (
@@ -4997,6 +5013,38 @@ def speech(
 ##### Health Endpoints #######################


+async def ahealth_check_chat_models(
+    model: str, custom_llm_provider: str, model_params: dict
+) -> dict:
+    if "*" in model:
+        from litellm.litellm_core_utils.llm_request_utils import (
+            pick_cheapest_chat_model_from_llm_provider,
+        )
+
+        # this is a wildcard model, we need to pick a random model from the provider
+        cheapest_model = pick_cheapest_chat_model_from_llm_provider(
+            custom_llm_provider=custom_llm_provider
+        )
+        fallback_models: Optional[List] = None
+        if custom_llm_provider in litellm.models_by_provider:
+            models = litellm.models_by_provider[custom_llm_provider]
+            random.shuffle(models)  # Shuffle the models list in place
+            fallback_models = models[
+                :2
+            ]  # Pick the first 2 models from the shuffled list
+        model_params["model"] = cheapest_model
+        model_params["fallbacks"] = fallback_models
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response: dict = {}  # args like remaining ratelimit etc.
+    else:  # default to completion calls
+        model_params["max_tokens"] = 1
+        await acompletion(**model_params)
+        response = {}  # args like remaining ratelimit etc.
+
+    return response
+
+
 async def ahealth_check(  # noqa: PLR0915
     model_params: dict,
     mode: Optional[
@@ -5128,21 +5176,12 @@ async def ahealth_check(  # noqa: PLR0915
             model_params["documents"] = ["my sample text"]
             await litellm.arerank(**model_params)
             response = {}
-        elif "*" in model:
-            from litellm.litellm_core_utils.llm_request_utils import (
-                pick_cheapest_chat_model_from_llm_provider,
+        else:
+            response = await ahealth_check_chat_models(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                model_params=model_params,
             )
-
-            # this is a wildcard model, we need to pick a random model from the provider
-            cheapest_model = pick_cheapest_chat_model_from_llm_provider(
-                custom_llm_provider=custom_llm_provider
-            )
-            model_params["model"] = cheapest_model
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
-        else:  # default to completion calls
-            await acompletion(**model_params)
-            response = {}  # args like remaining ratelimit etc.
         return response
     except Exception as e:
         stack_trace = traceback.format_exc()
diff --git a/litellm/utils.py b/litellm/utils.py
index d0afd1831a..6bcea7174b 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5697,6 +5697,72 @@ def completion_with_fallbacks(**kwargs):
     return response


+async def async_completion_with_fallbacks(**kwargs):
+    nested_kwargs = kwargs.pop("kwargs", {})
+    response = None
+    rate_limited_models = set()
+    model_expiration_times = {}
+    start_time = time.time()
+    original_model = kwargs["model"]
+    fallbacks = [kwargs["model"]] + nested_kwargs.get("fallbacks", [])
+    if "fallbacks" in nested_kwargs:
+        del nested_kwargs["fallbacks"]  # remove fallbacks so it's not recursive
+    if "acompletion" in kwargs:
+        del kwargs[
+            "acompletion"
+        ]  # remove acompletion so it doesn't lead to keyword errors
+    litellm_call_id = str(uuid.uuid4())
+
+    # max time to process a request with fallbacks: default 45s
+    while response is None and time.time() - start_time < 45:
+        for model in fallbacks:
+            # loop thru all models
+            try:
+                # check if it's dict or new model string
+                if isinstance(
+                    model, dict
+                ):  # completion(model="gpt-4", fallbacks=[{"api_key": "", "api_base": ""}, {"api_key": "", "api_base": ""}])
+                    kwargs["api_key"] = model.get("api_key", None)
+                    kwargs["api_base"] = model.get("api_base", None)
+                    model = model.get("model", original_model)
+                elif (
+                    model in rate_limited_models
+                ):  # check if model is currently cooling down
+                    if (
+                        model_expiration_times.get(model)
+                        and time.time() >= model_expiration_times[model]
+                    ):
+                        rate_limited_models.remove(
+                            model
+                        )  # check if it's been 60s of cool down and remove model
+                    else:
+                        continue  # skip model
+
+                # delete model from kwargs if it exists
+                if kwargs.get("model"):
+                    del kwargs["model"]
+
+                print_verbose(f"trying to make completion call with model: {model}")
+                kwargs["litellm_call_id"] = litellm_call_id
+                kwargs = {
+                    **kwargs,
+                    **nested_kwargs,
+                }  # combine the openai + litellm params at the same level
+                response = await litellm.acompletion(**kwargs, model=model)
+                print_verbose(f"response: {response}")
+                if response is not None:
+                    return response
+
+            except Exception as e:
+                print_verbose(f"error: {e}")
+                rate_limited_models.add(model)
+                model_expiration_times[model] = (
+                    time.time() + 60
+                )  # cool down this selected model
+                pass
+    return response
+
+
 def process_system_message(system_message, max_tokens, model):
     system_message_event = {"role": "system", "content": system_message}
     system_message_tokens = get_token_count([system_message_event], model)
diff --git a/tests/llm_translation/test_cohere_generate_api.py b/tests/llm_translation/test_cohere_generate_api.py
index 9e0bb82846..837e077cde 100644
--- a/tests/llm_translation/test_cohere_generate_api.py
+++ b/tests/llm_translation/test_cohere_generate_api.py
@@ -20,10 +20,13 @@ from litellm import completion
 from litellm.llms.cohere.completion.transformation import CohereTextConfig


-@pytest.mark.asyncio
-async def test_cohere_generate_api_completion():
+def test_cohere_generate_api_completion():
     try:
-        litellm.set_verbose = False
+        from litellm.llms.custom_httpx.http_handler import HTTPHandler
+        from unittest.mock import patch, MagicMock
+
+        client = HTTPHandler()
+        litellm.set_verbose = True
         messages = [
             {"role": "system", "content": "You're a good bot"},
             {
@@ -31,12 +34,28 @@ def test_cohere_generate_api_completion():
                 "content": "Hey",
             },
         ]
-        response = completion(
-            model="cohere/command-nightly",
-            messages=messages,
-            max_tokens=10,
-        )
-        print(response)
+
+        with patch.object(client, "post") as mock_client:
+            try:
+                completion(
+                    model="cohere/command",
+                    messages=messages,
+                    max_tokens=10,
+                    client=client,
+                )
+            except Exception as e:
+                print(e)
+            mock_client.assert_called_once()
+            print("mock_client.call_args.kwargs", mock_client.call_args.kwargs)
+
+            assert (
+                mock_client.call_args.kwargs["url"]
+                == "https://api.cohere.ai/v1/generate"
+            )
+            json_data = json.loads(mock_client.call_args.kwargs["data"])
+            assert json_data["model"] == "command"
+            assert json_data["prompt"] == "You're a good bot Hey"
+            assert json_data["max_tokens"] == 10
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
@@ -53,7 +72,7 @@ async def test_cohere_generate_api_stream():
             },
         ]
         response = await litellm.acompletion(
-            model="cohere/command-nightly",
+            model="cohere/command",
             messages=messages,
             max_tokens=10,
             stream=True,
@@ -76,7 +95,7 @@ def test_completion_cohere_stream_bad_key():
             },
         ]
         completion(
-            model="command-nightly",
+            model="command",
             messages=messages,
             stream=True,
             max_tokens=50,
@@ -100,7 +119,7 @@ def test_cohere_transform_request():
     headers = {}

     transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
         messages=messages,
         optional_params=optional_params,
         litellm_params={},
@@ -109,7 +128,7 @@ def test_cohere_transform_request():

     print("transformed_request", json.dumps(transformed_request, indent=4))

-    assert transformed_request["model"] == "command-nightly"
+    assert transformed_request["model"] == "command"
     assert transformed_request["prompt"] == "You're a helpful bot Hello"
     assert transformed_request["max_tokens"] == 10
     assert transformed_request["temperature"] == 0.7
@@ -137,7 +156,7 @@ def test_cohere_transform_request_with_tools():
     optional_params = {"tools": tools}

     transformed_request = config.transform_request(
-        model="command-nightly",
+        model="command",
         messages=messages,
         optional_params=optional_params,
         litellm_params={},
@@ -168,7 +187,7 @@ def test_cohere_map_openai_params():
     mapped_params = config.map_openai_params(
         non_default_params=openai_params,
         optional_params={},
-        model="command-nightly",
+        model="command",
         drop_params=False,
     )

diff --git a/tests/llm_translation/test_gpt4o_audio.py b/tests/llm_translation/test_gpt4o_audio.py
index 2eae06a446..6174cac734 100644
--- a/tests/llm_translation/test_gpt4o_audio.py
+++ b/tests/llm_translation/test_gpt4o_audio.py
@@ -56,13 +56,17 @@ async def test_audio_output_from_model(stream):
     if stream is False:
         audio_format = "wav"
     litellm.set_verbose = False
-    completion = await litellm.acompletion(
-        model="gpt-4o-audio-preview",
-        modalities=["text", "audio"],
-        audio={"voice": "alloy", "format": "pcm16"},
-        messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
-        stream=stream,
-    )
+    try:
+        completion = await litellm.acompletion(
+            model="gpt-4o-audio-preview",
+            modalities=["text", "audio"],
+            audio={"voice": "alloy", "format": "pcm16"},
+            messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
+            stream=stream,
+        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

     if stream is True:
         await check_streaming_response(completion)
@@ -88,25 +92,28 @@ async def test_audio_input_to_model(stream):
     response.raise_for_status()
     wav_data = response.content
     encoded_string = base64.b64encode(wav_data).decode("utf-8")
-
-    completion = await litellm.acompletion(
-        model="gpt-4o-audio-preview",
-        modalities=["text", "audio"],
-        audio={"voice": "alloy", "format": audio_format},
-        stream=stream,
-        messages=[
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": "What is in this recording?"},
-                    {
-                        "type": "input_audio",
-                        "input_audio": {"data": encoded_string, "format": "wav"},
-                    },
-                ],
-            },
-        ],
-    )
+    try:
+        completion = await litellm.acompletion(
+            model="gpt-4o-audio-preview",
+            modalities=["text", "audio"],
+            audio={"voice": "alloy", "format": audio_format},
+            stream=stream,
+            messages=[
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "What is in this recording?"},
+                        {
+                            "type": "input_audio",
+                            "input_audio": {"data": encoded_string, "format": "wav"},
+                        },
+                    ],
+                },
+            ],
+        )
+    except litellm.Timeout as e:
+        print(e)
+        pytest.skip("Skipping test due to timeout")

     if stream is True:
         await check_streaming_response(completion)
diff --git a/tests/local_testing/test_health_check.py b/tests/local_testing/test_health_check.py
index 71c6c42170..b0e8d1c3b0 100644
--- a/tests/local_testing/test_health_check.py
+++ b/tests/local_testing/test_health_check.py
@@ -112,16 +112,17 @@ async def test_sagemaker_embedding_health_check():


 @pytest.mark.asyncio
-async def test_fireworks_health_check():
+async def test_groq_health_check():
     """
     This should not fail

    ensure that provider wildcard model passes health check
    """
+    litellm.set_verbose = True
     response = await litellm.ahealth_check(
         model_params={
-            "api_key": os.environ.get("FIREWORKS_AI_API_KEY"),
-            "model": "fireworks_ai/*",
+            "api_key": os.environ.get("GROQ_API_KEY"),
+            "model": "groq/*",
             "messages": [{"role": "user", "content": "What's 1 + 1?"}],
         },
         mode=None,
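
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the two user-facing changes in this diff -- the `fallbacks` kwarg that acompletion() now routes through async_completion_with_fallbacks(), and the `client` kwarg that completion() forwards to BaseLLMHTTPHandler for the cohere text-generation path so a caller-supplied HTTPHandler is reused instead of a freshly constructed httpx client. Model names are placeholders and the calls assume the relevant provider API keys are set in the environment.

import asyncio

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler


async def fallback_example() -> None:
    # `fallbacks` is read from kwargs in acompletion() and handed to
    # async_completion_with_fallbacks(); each fallback model is tried in turn
    # when the primary model errors (placeholder model names).
    response = await litellm.acompletion(
        model="cohere/command",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        fallbacks=["groq/llama3-8b-8192"],
    )
    print(response.choices[0].message.content)


def reused_client_example() -> None:
    # A caller-supplied HTTPHandler is now reused by BaseLLMHTTPHandler's sync
    # path instead of a freshly created client (mirrors the mocked call in
    # test_cohere_generate_api_completion above).
    client = HTTPHandler()
    litellm.completion(
        model="cohere/command",
        messages=[{"role": "user", "content": "Hey"}],
        max_tokens=10,
        client=client,
    )


if __name__ == "__main__":
    asyncio.run(fallback_example())
    reused_client_example()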