diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py
index 0bf74c5dca..037351d0e6 100644
--- a/litellm/litellm_core_utils/get_llm_provider_logic.py
+++ b/litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -129,17 +129,15 @@ def get_llm_provider(  # noqa: PLR0915
             model, custom_llm_provider
         )
 
-        if custom_llm_provider:
-            if (
-                model.split("/")[0] == custom_llm_provider
-            ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
-                model = model.replace("{}/".format(custom_llm_provider), "")
-
-            return model, custom_llm_provider, dynamic_api_key, api_base
+        if custom_llm_provider and (
+            model.split("/")[0] != custom_llm_provider
+        ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
+            model = custom_llm_provider + "/" + model
 
         if api_key and api_key.startswith("os.environ/"):
             dynamic_api_key = get_secret_str(api_key)
         # check if llm provider part of model name
+
         if (
             model.split("/", 1)[0] in litellm.provider_list
             and model.split("/", 1)[0] not in litellm.model_list_set
@@ -573,11 +571,11 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
         dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
     elif custom_llm_provider == "snowflake":
         api_base = (
-            api_base
-            or get_secret("SNOWFLAKE_API_BASE")
-            or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
-        )  # type: ignore
-        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+            api_base
+            or get_secret_str("SNOWFLAKE_API_BASE")
+            or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
+        )  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")
 
     if api_base is not None and not isinstance(api_base, str):
         raise Exception("api base needs to be a string. api_base={}".format(api_base))
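
A minimal sketch of the behavior this change produces, mirroring the new test_xai_api_base test further down (the model and key values are illustrative, taken from that test rather than any live credentials): when custom_llm_provider is supplied and the model string lacks the provider prefix, get_llm_provider now normalizes to "provider/model" internally instead of stripping the prefix and returning early, so bare and prefixed model names resolve identically.

import litellm

# Bare model name plus an explicit provider: the prefix is added internally,
# then stripped again from the returned model string.
model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="grok-2-vision-latest",  # no "xai/" prefix
    custom_llm_provider="xai",
    api_key="xai-my-specialkey",
)
assert (model, provider) == ("grok-2-vision-latest", "xai")
assert api_base == "https://api.x.ai/v1"  # provider default, resolved downstream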
api_base={}".format(api_base)) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 0945c45491..9b3efe94de 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -3258,6 +3258,7 @@ class StandardLoggingPayloadSetup: additional_headers=None, litellm_overhead_time_ms=None, batch_models=None, + litellm_model_name=None, ) if hidden_params is not None: for key in StandardLoggingHiddenParams.__annotations__.keys(): @@ -3372,6 +3373,7 @@ def get_standard_logging_object_payload( response_cost=None, litellm_overhead_time_ms=None, batch_models=None, + litellm_model_name=None, ) ) @@ -3657,6 +3659,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload: additional_headers=None, litellm_overhead_time_ms=None, batch_models=None, + litellm_model_name=None, ) # Convert numeric values to appropriate types diff --git a/litellm/litellm_core_utils/llm_response_utils/response_metadata.py b/litellm/litellm_core_utils/llm_response_utils/response_metadata.py index 03595e27a4..84c80174f9 100644 --- a/litellm/litellm_core_utils/llm_response_utils/response_metadata.py +++ b/litellm/litellm_core_utils/llm_response_utils/response_metadata.py @@ -44,6 +44,7 @@ class ResponseMetadata: "additional_headers": process_response_headers( self._get_value_from_hidden_params("additional_headers") or {} ), + "litellm_model_name": model, } self._update_hidden_params(new_params) diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py index 64f086036b..667246ea2f 100644 --- a/litellm/router_strategy/lowest_tpm_rpm_v2.py +++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py @@ -65,10 +65,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger): # ------------ # Setup values # ------------ + dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") model_id = deployment.get("model_info", {}).get("id") - rpm_key = f"{model_id}:rpm:{current_minute}" + deployment_name = deployment.get("litellm_params", {}).get("model") + rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}" + local_result = self.router_cache.get_cache( key=rpm_key, local_only=True ) # check local result first @@ -151,7 +154,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger): dt = get_utc_datetime() current_minute = dt.strftime("%H-%M") model_id = deployment.get("model_info", {}).get("id") - rpm_key = f"{model_id}:rpm:{current_minute}" + deployment_name = deployment.get("litellm_params", {}).get("model") + + rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}" local_result = await self.router_cache.async_get_cache( key=rpm_key, local_only=True ) # check local result first @@ -228,8 +233,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger): if standard_logging_object is None: raise ValueError("standard_logging_object not passed in.") model_group = standard_logging_object.get("model_group") + model = standard_logging_object["hidden_params"].get("litellm_model_name") id = standard_logging_object.get("model_id") - if model_group is None or id is None: + if model_group is None or id is None or model is None: return elif isinstance(id, int): id = str(id) @@ -244,7 +250,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger): "%H-%M" ) # use the same timezone regardless of system clock - tpm_key = f"{id}:tpm:{current_minute}" + tpm_key = f"{id}:{model}:tpm:{current_minute}" # ------------ # Update usage # ------------ @@ -276,6 +282,7 @@ class 
diff --git a/litellm/router_strategy/lowest_tpm_rpm_v2.py b/litellm/router_strategy/lowest_tpm_rpm_v2.py
index 64f086036b..667246ea2f 100644
--- a/litellm/router_strategy/lowest_tpm_rpm_v2.py
+++ b/litellm/router_strategy/lowest_tpm_rpm_v2.py
@@ -65,10 +65,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             # ------------
             # Setup values
             # ------------
+
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
             model_id = deployment.get("model_info", {}).get("id")
-            rpm_key = f"{model_id}:rpm:{current_minute}"
+            deployment_name = deployment.get("litellm_params", {}).get("model")
+            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
+
             local_result = self.router_cache.get_cache(
                 key=rpm_key, local_only=True
             )  # check local result first
@@ -151,7 +154,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
             model_id = deployment.get("model_info", {}).get("id")
-            rpm_key = f"{model_id}:rpm:{current_minute}"
+            deployment_name = deployment.get("litellm_params", {}).get("model")
+
+            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
             local_result = await self.router_cache.async_get_cache(
                 key=rpm_key, local_only=True
             )  # check local result first
@@ -228,8 +233,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
            if standard_logging_object is None:
                raise ValueError("standard_logging_object not passed in.")
            model_group = standard_logging_object.get("model_group")
+            model = standard_logging_object["hidden_params"].get("litellm_model_name")
            id = standard_logging_object.get("model_id")
-            if model_group is None or id is None:
+            if model_group is None or id is None or model is None:
                return
            elif isinstance(id, int):
                id = str(id)
@@ -244,7 +250,7 @@
                "%H-%M"
            )  # use the same timezone regardless of system clock
 
-            tpm_key = f"{id}:tpm:{current_minute}"
+            tpm_key = f"{id}:{model}:tpm:{current_minute}"
            # ------------
            # Update usage
            # ------------
@@ -276,6 +282,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
            if standard_logging_object is None:
                raise ValueError("standard_logging_object not passed in.")
            model_group = standard_logging_object.get("model_group")
+            model = standard_logging_object["hidden_params"]["litellm_model_name"]
            id = standard_logging_object.get("model_id")
            if model_group is None or id is None:
                return
@@ -290,7 +297,7 @@
                "%H-%M"
            )  # use the same timezone regardless of system clock
 
-            tpm_key = f"{id}:tpm:{current_minute}"
+            tpm_key = f"{id}:{model}:tpm:{current_minute}"
            # ------------
            # Update usage
            # ------------
@@ -458,8 +465,9 @@
            id = m.get("model_info", {}).get(
                "id"
            )  # a deployment should always have an 'id'. this is set in router.py
-            tpm_key = "{}:tpm:{}".format(id, current_minute)
-            rpm_key = "{}:rpm:{}".format(id, current_minute)
+            deployment_name = m.get("litellm_params", {}).get("model")
+            tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+            rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
            tpm_keys.append(tpm_key)
            rpm_keys.append(rpm_key)
 
@@ -576,8 +584,9 @@
            id = m.get("model_info", {}).get(
                "id"
            )  # a deployment should always have an 'id'. this is set in router.py
-            tpm_key = "{}:tpm:{}".format(id, current_minute)
-            rpm_key = "{}:rpm:{}".format(id, current_minute)
+            deployment_name = m.get("litellm_params", {}).get("model")
+            tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+            rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)
            tpm_keys.append(tpm_key)
            rpm_keys.append(rpm_key)
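
The net effect on the cache schema: the per-minute TPM/RPM counters are now keyed by both the router's deployment id and the deployment's underlying model name. A self-contained sketch of the new key construction (values illustrative; the format comes directly from the hunks above):

from datetime import datetime, timezone

deployment = {
    "model_info": {"id": "1234"},
    "litellm_params": {"model": "azure/chatgpt-v-2"},
}

model_id = deployment.get("model_info", {}).get("id")
deployment_name = deployment.get("litellm_params", {}).get("model")
current_minute = datetime.now(timezone.utc).strftime("%H-%M")

# Old scheme: "1234:tpm:14-05"; the new scheme adds the deployment's model name.
tpm_key = f"{model_id}:{deployment_name}:tpm:{current_minute}"
rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
print(tpm_key)  # e.g. 1234:azure/chatgpt-v-2:tpm:14-05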
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index a2d41d8fb9..5ecd490ee8 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1625,13 +1625,16 @@ class StandardLoggingAdditionalHeaders(TypedDict, total=False):
 
 
 class StandardLoggingHiddenParams(TypedDict):
-    model_id: Optional[str]
+    model_id: Optional[
+        str
+    ]  # id of the model in the router, separates multiple models with the same name but different credentials
     cache_key: Optional[str]
     api_base: Optional[str]
     response_cost: Optional[str]
     litellm_overhead_time_ms: Optional[float]
     additional_headers: Optional[StandardLoggingAdditionalHeaders]
     batch_models: Optional[List[str]]
+    litellm_model_name: Optional[str]  # the model name sent to the provider by litellm
 
 
 class StandardLoggingModelInformation(TypedDict):
diff --git a/tests/litellm/caching/test_redis_cache.py b/tests/litellm/caching/test_redis_cache.py
index 3b7bdc5629..10064c0b13 100644
--- a/tests/litellm/caching/test_redis_cache.py
+++ b/tests/litellm/caching/test_redis_cache.py
@@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache
 
 @pytest.mark.parametrize("namespace", [None, "test"])
 @pytest.mark.asyncio
-async def test_redis_cache_async_increment(namespace):
+async def test_redis_cache_async_increment(namespace, monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
     redis_cache = RedisCache(namespace=namespace)
     # Create an AsyncMock for the Redis client
     mock_redis_instance = AsyncMock()
@@ -46,7 +47,8 @@
 
 
 @pytest.mark.asyncio
-async def test_redis_client_init_with_socket_timeout():
+async def test_redis_client_init_with_socket_timeout(monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
     redis_cache = RedisCache(socket_timeout=1.0)
     assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
     client = redis_cache.init_async_client()
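
The redis test changes pin REDIS_HOST through pytest's monkeypatch fixture so RedisCache can be constructed on machines with no Redis configured; the working assumption, consistent with these tests, is that RedisCache falls back to environment variables for connection details when none are passed explicitly. monkeypatch scopes the variable to the test:

import os

def test_env_is_scoped_to_the_test(monkeypatch):
    # The value is arbitrary: the tests above mock the client and never connect.
    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
    assert os.environ["REDIS_HOST"] == "my-fake-host"
    # pytest restores (or removes) REDIS_HOST automatically at teardown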
"api_key": os.getenv("AZURE_AI_COHERE_API_KEY"), } + + +@pytest.mark.asyncio +async def test_azure_ai_request_format(): + """ + Test that Azure AI requests are formatted correctly with the proper endpoint and parameters + for both synchronous and asynchronous calls + """ + from openai import AsyncAzureOpenAI, AzureOpenAI + + litellm._turn_on_debug() + + # Set up the test parameters + api_key = os.getenv("AZURE_API_KEY") + api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview" + model = "azure_ai/gpt-4o" + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello! How can I assist you today?"}, + {"role": "user", "content": "hi"}, + ] + + await litellm.acompletion( + custom_llm_provider="azure_ai", + api_key=api_key, + api_base=api_base, + model=model, + messages=messages, + ) diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py index c3f4c15c27..fa27a8378c 100644 --- a/tests/local_testing/test_get_llm_provider.py +++ b/tests/local_testing/test_get_llm_provider.py @@ -216,3 +216,21 @@ def test_bedrock_invoke_anthropic(): ) assert custom_llm_provider == "bedrock" assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0" + + +@pytest.mark.parametrize("model", ["xai/grok-2-vision-latest", "grok-2-vision-latest"]) +def test_xai_api_base(model): + args = { + "model": model, + "custom_llm_provider": "xai", + "api_base": None, + "api_key": "xai-my-specialkey", + "litellm_params": None, + } + model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider( + **args + ) + assert custom_llm_provider == "xai" + assert model == "grok-2-vision-latest" + assert api_base == "https://api.x.ai/v1" + assert dynamic_api_key == "xai-my-specialkey" diff --git a/tests/local_testing/test_tpm_rpm_routing_v2.py b/tests/local_testing/test_tpm_rpm_routing_v2.py index a7073b4acd..d2b951a187 100644 --- a/tests/local_testing/test_tpm_rpm_routing_v2.py +++ b/tests/local_testing/test_tpm_rpm_routing_v2.py @@ -20,7 +20,7 @@ sys.path.insert( from unittest.mock import AsyncMock, MagicMock, patch from litellm.types.utils import StandardLoggingPayload import pytest - +from litellm.types.router import DeploymentTypedDict import litellm from litellm import Router from litellm.caching.caching import DualCache @@ -47,12 +47,14 @@ def test_tpm_rpm_updated(): deployment_id = "1234" deployment = "azure/chatgpt-v-2" total_tokens = 50 - standard_logging_payload = create_standard_logging_payload() + standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload() standard_logging_payload["model_group"] = model_group standard_logging_payload["model_id"] = deployment_id standard_logging_payload["total_tokens"] = total_tokens + standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment kwargs = { "litellm_params": { + "model": deployment, "metadata": { "model_group": model_group, "deployment": deployment, @@ -62,10 +64,16 @@ def test_tpm_rpm_updated(): "standard_logging_object": standard_logging_payload, } + litellm_deployment_dict: DeploymentTypedDict = { + "model_name": model_group, + "litellm_params": {"model": deployment}, + "model_info": {"id": deployment_id}, + } + start_time = time.time() response_obj = {"usage": {"total_tokens": total_tokens}} end_time = time.time() - lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"]) + lowest_tpm_logger.pre_call_check(deployment=litellm_deployment_dict) 
diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py
index c3f4c15c27..fa27a8378c 100644
--- a/tests/local_testing/test_get_llm_provider.py
+++ b/tests/local_testing/test_get_llm_provider.py
@@ -216,3 +216,21 @@ def test_bedrock_invoke_anthropic():
     )
     assert custom_llm_provider == "bedrock"
     assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+
+@pytest.mark.parametrize("model", ["xai/grok-2-vision-latest", "grok-2-vision-latest"])
+def test_xai_api_base(model):
+    args = {
+        "model": model,
+        "custom_llm_provider": "xai",
+        "api_base": None,
+        "api_key": "xai-my-specialkey",
+        "litellm_params": None,
+    }
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        **args
+    )
+    assert custom_llm_provider == "xai"
+    assert model == "grok-2-vision-latest"
+    assert api_base == "https://api.x.ai/v1"
+    assert dynamic_api_key == "xai-my-specialkey"
diff --git a/tests/local_testing/test_tpm_rpm_routing_v2.py b/tests/local_testing/test_tpm_rpm_routing_v2.py
index a7073b4acd..d2b951a187 100644
--- a/tests/local_testing/test_tpm_rpm_routing_v2.py
+++ b/tests/local_testing/test_tpm_rpm_routing_v2.py
@@ -20,7 +20,7 @@ sys.path.insert(
 from unittest.mock import AsyncMock, MagicMock, patch
 from litellm.types.utils import StandardLoggingPayload
 import pytest
-
+from litellm.types.router import DeploymentTypedDict
 import litellm
 from litellm import Router
 from litellm.caching.caching import DualCache
@@ -47,12 +47,14 @@ def test_tpm_rpm_updated():
     deployment_id = "1234"
     deployment = "azure/chatgpt-v-2"
     total_tokens = 50
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
+            "model": deployment,
             "metadata": {
                 "model_group": model_group,
                 "deployment": deployment,
@@ -62,10 +64,16 @@
         "standard_logging_object": standard_logging_payload,
     }
 
+    litellm_deployment_dict: DeploymentTypedDict = {
+        "model_name": model_group,
+        "litellm_params": {"model": deployment},
+        "model_info": {"id": deployment_id},
+    }
+
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
-    lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
+    lowest_tpm_logger.pre_call_check(deployment=litellm_deployment_dict)
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
@@ -74,8 +82,8 @@
     )
     dt = get_utc_datetime()
     current_minute = dt.strftime("%H-%M")
-    tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
-    rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
+    tpm_count_api_key = f"{deployment_id}:{deployment}:tpm:{current_minute}"
+    rpm_count_api_key = f"{deployment_id}:{deployment}:rpm:{current_minute}"
 
     print(f"tpm_count_api_key={tpm_count_api_key}")
     assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
@@ -113,6 +121,7 @@ def test_get_available_deployments():
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -135,10 +144,11 @@
     ## DEPLOYMENT 2 ##
     total_tokens = 20
     deployment_id = "5678"
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -209,11 +219,12 @@
     print(f"router id's: {router.get_model_ids()}")
     ## DEPLOYMENT 1 ##
     deployment_id = 1
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
     total_tokens = 50
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -237,6 +248,9 @@
     standard_logging_payload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
+    standard_logging_payload["hidden_params"][
+        "litellm_model_name"
+    ] = "azure/gpt-35-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -293,10 +307,11 @@
     ## DEPLOYMENT 1 ##
     deployment_id = 1
     total_tokens = 1439
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
routing_strategy="usage-based-routing-v2") + + async def side_effect_pre_call_check(*args, **kwargs): + return args[0] + + with patch.object( + router.lowesttpm_logger_v2, + "async_pre_call_check", + side_effect=side_effect_pre_call_check, + ) as mock_object, patch.object( + router.lowesttpm_logger_v2, "async_log_success_event" + ) as mock_logging_event: + response = await router.acompletion( + model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey!"}] + ) + + mock_object.assert_called() + print(f"mock_object.call_args: {mock_object.call_args[0][0]}") + assert ( + mock_object.call_args[0][0]["litellm_params"]["model"] + == deployment["litellm_params"]["model"] + ) + + await asyncio.sleep(1) + + mock_logging_event.assert_called() + + print(f"mock_logging_event: {mock_logging_event.call_args.kwargs}") + standard_logging_payload: StandardLoggingPayload = ( + mock_logging_event.call_args.kwargs.get("kwargs", {}).get( + "standard_logging_object" + ) + ) + + assert ( + standard_logging_payload["hidden_params"]["litellm_model_name"] + == "azure/chatgpt-v-2" + )