diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index f5c2349635..2ad3af7b9b 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -3658,6 +3658,7 @@ async def chat_completion(
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
 
         # Post Call Processing
         if llm_router is not None:
@@ -3670,6 +3671,7 @@ async def chat_completion(
             custom_headers = {
                 "x-litellm-model-id": model_id,
                 "x-litellm-cache-key": cache_key,
+                "x-litellm-model-api-base": api_base,
             }
             selected_data_generator = select_data_generator(
                 response=response, user_api_key_dict=user_api_key_dict
@@ -3682,6 +3684,7 @@ async def chat_completion(
 
         fastapi_response.headers["x-litellm-model-id"] = model_id
         fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
diff --git a/litellm/tests/test_alerting.py b/litellm/tests/test_alerting.py
index a74e25910c..40c75b86b7 100644
--- a/litellm/tests/test_alerting.py
+++ b/litellm/tests/test_alerting.py
@@ -15,6 +15,7 @@ import litellm
 import pytest
 import asyncio
 from unittest.mock import patch, MagicMock
+from litellm.utils import get_api_base
 from litellm.caching import DualCache
 from litellm.integrations.slack_alerting import SlackAlerting
 from litellm.proxy._types import UserAPIKeyAuth
@@ -74,6 +75,19 @@ async def test_slack_alerting_llm_exceptions(exception_type, monkeypatch):
     await asyncio.sleep(2)
 
 
+@pytest.mark.parametrize(
+    "model, optional_params, expected_api_base",
+    [
+        ("openai/my-fake-model", {"api_base": "my-fake-api-base"}, "my-fake-api-base"),
+        ("gpt-3.5-turbo", {}, "https://api.openai.com"),
+    ],
+)
+def test_get_api_base_unit_test(model, optional_params, expected_api_base):
+    api_base = get_api_base(model=model, optional_params=optional_params)
+
+    assert api_base == expected_api_base
+
+
 @pytest.mark.asyncio
 async def test_get_api_base():
     _pl = ProxyLogging(user_api_key_cache=DualCache())
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 64b71b999e..068a99b005 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -99,6 +99,7 @@ class ModelInfo(BaseModel):
 
 class LiteLLM_Params(BaseModel):
     model: str
+    custom_llm_provider: Optional[str] = None
     tpm: Optional[int] = None
     rpm: Optional[int] = None
     api_key: Optional[str] = None
@@ -123,6 +124,7 @@ class LiteLLM_Params(BaseModel):
     def __init__(
         self,
         model: str,
+        custom_llm_provider: Optional[str] = None,
         max_retries: Optional[Union[int, str]] = None,
         tpm: Optional[int] = None,
         rpm: Optional[int] = None,
diff --git a/litellm/utils.py b/litellm/utils.py
index 80d26f58b9..589ea4d078 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -315,6 +315,7 @@ class ChatCompletionDeltaToolCall(OpenAIObject):
 class HiddenParams(OpenAIObject):
     original_response: Optional[str] = None
     model_id: Optional[str] = None  # used in Router for individual deployments
+    api_base: Optional[str] = None  # returns api base used for making completion call
 
     class Config:
         extra = "allow"
@@ -3157,6 +3158,10 @@ def client(original_function):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                     "id", None
                 )
+                result._hidden_params["api_base"] = get_api_base(
+                    model=model,
+                    optional_params=getattr(logging_obj, "optional_params", {}),
+                )
                 result._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai
@@ -3226,6 +3231,8 @@ def client(original_function):
         call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
+
+        model = ""
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
@@ -3547,6 +3554,10 @@ def client(original_function):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                     "id", None
                 )
+                result._hidden_params["api_base"] = get_api_base(
+                    model=model,
+                    optional_params=kwargs,
+                )
                 if (
                     isinstance(result, ModelResponse)
                     or isinstance(result, EmbeddingResponse)
@@ -5810,19 +5821,40 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
         get_api_base(model="gemini/gemini-pro")
         ```
     """
-    _optional_params = LiteLLM_Params(
-        model=model, **optional_params
-    )  # convert to pydantic object
-    # get llm provider
+
     try:
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
-            model=model
-        )
-    except:
-        custom_llm_provider = None
+        if "model" in optional_params:
+            _optional_params = LiteLLM_Params(**optional_params)
+        else:  # prevent needing to copy and pop the dict
+            _optional_params = LiteLLM_Params(
+                model=model, **optional_params
+            )  # convert to pydantic object
+    except Exception as e:
+        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        return None
+    # get llm provider
+
     if _optional_params.api_base is not None:
         return _optional_params.api_base
 
+    try:
+        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
+            get_llm_provider(
+                model=model,
+                custom_llm_provider=_optional_params.custom_llm_provider,
+                api_base=_optional_params.api_base,
+                api_key=_optional_params.api_key,
+            )
+        )
+    except Exception as e:
+        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        custom_llm_provider = None
+        dynamic_api_key = None
+        dynamic_api_base = None
+
+    if dynamic_api_base is not None:
+        return dynamic_api_base
+
     if (
         _optional_params.vertex_location is not None
         and _optional_params.vertex_project is not None
@@ -5835,11 +5867,17 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
         )
         return _api_base
 
-    if custom_llm_provider is not None and custom_llm_provider == "gemini":
+    if custom_llm_provider is None:
+        return None
+
+    if custom_llm_provider == "gemini":
         _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
             model
         )
         return _api_base
+    elif custom_llm_provider == "openai":
+        _api_base = "https://api.openai.com"
+        return _api_base
     return None
 
 
@@ -6147,7 +6185,6 @@ def get_llm_provider(
     try:
         dynamic_api_key = None
         # check if llm provider provided
-
        # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
        # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
        if model.split("/", 1)[0] == "azure":