Merge pull request #3430 from BerriAI/litellm_return_api_base

feat(proxy_server.py): return api base in response headers
Krish Dholakia 2024-05-03 16:25:21 -07:00 committed by GitHub
commit 1b35a75245
4 changed files with 67 additions and 11 deletions


@@ -3658,6 +3658,7 @@ async def chat_completion(
    hidden_params = getattr(response, "_hidden_params", {}) or {}
    model_id = hidden_params.get("model_id", None) or ""
    cache_key = hidden_params.get("cache_key", None) or ""
+   api_base = hidden_params.get("api_base", None) or ""

    # Post Call Processing
    if llm_router is not None:
@@ -3670,6 +3671,7 @@ async def chat_completion(
        custom_headers = {
            "x-litellm-model-id": model_id,
            "x-litellm-cache-key": cache_key,
+           "x-litellm-model-api-base": api_base,
        }
        selected_data_generator = select_data_generator(
            response=response, user_api_key_dict=user_api_key_dict
@@ -3682,6 +3684,7 @@ async def chat_completion(
    fastapi_response.headers["x-litellm-model-id"] = model_id
    fastapi_response.headers["x-litellm-cache-key"] = cache_key
+   fastapi_response.headers["x-litellm-model-api-base"] = api_base

    ### CALL HOOKS ### - modify outgoing data
    response = await proxy_logging_obj.post_call_success_hook(
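
For context on how the new header surfaces to callers, here is a minimal sketch of reading it from a proxy response. The proxy URL (`http://localhost:4000`), virtual key (`sk-1234`), and model name are placeholders for whatever your deployment uses:

```python
import requests

# Hypothetical proxy URL and virtual key -- substitute your own values.
resp = requests.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
    },
)

# Existing headers plus the one added by this PR.
print(resp.headers.get("x-litellm-model-id"))
print(resp.headers.get("x-litellm-cache-key"))
print(resp.headers.get("x-litellm-model-api-base"))  # e.g. https://api.openai.com
```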


@@ -15,6 +15,7 @@ import litellm
import pytest
import asyncio
from unittest.mock import patch, MagicMock
+from litellm.utils import get_api_base
from litellm.caching import DualCache
from litellm.integrations.slack_alerting import SlackAlerting
from litellm.proxy._types import UserAPIKeyAuth
@@ -74,6 +75,19 @@ async def test_slack_alerting_llm_exceptions(exception_type, monkeypatch):
    await asyncio.sleep(2)


+@pytest.mark.parametrize(
+    "model, optional_params, expected_api_base",
+    [
+        ("openai/my-fake-model", {"api_base": "my-fake-api-base"}, "my-fake-api-base"),
+        ("gpt-3.5-turbo", {}, "https://api.openai.com"),
+    ],
+)
+def test_get_api_base_unit_test(model, optional_params, expected_api_base):
+    api_base = get_api_base(model=model, optional_params=optional_params)
+
+    assert api_base == expected_api_base
+
+
@pytest.mark.asyncio
async def test_get_api_base():
    _pl = ProxyLogging(user_api_key_cache=DualCache())


@@ -99,6 +99,7 @@ class ModelInfo(BaseModel):
class LiteLLM_Params(BaseModel):
    model: str
+   custom_llm_provider: Optional[str] = None
    tpm: Optional[int] = None
    rpm: Optional[int] = None
    api_key: Optional[str] = None
@@ -123,6 +124,7 @@ class LiteLLM_Params(BaseModel):
    def __init__(
        self,
        model: str,
+       custom_llm_provider: Optional[str] = None,
        max_retries: Optional[Union[int, str]] = None,
        tpm: Optional[int] = None,
        rpm: Optional[int] = None,
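
A rough illustration of the new field, assuming `LiteLLM_Params` is importable from `litellm.types.router` (the import path may vary across versions); the other constructor arguments shown are pre-existing optional fields and the endpoint is a placeholder:

```python
from litellm.types.router import LiteLLM_Params  # import path assumed

params = LiteLLM_Params(
    model="azure/my-deployment",
    custom_llm_provider="azure",  # new optional field from this PR
    api_base="https://my-endpoint.openai.azure.com",  # placeholder endpoint
)
print(params.custom_llm_provider)  # -> "azure"
```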


@@ -315,6 +315,7 @@ class ChatCompletionDeltaToolCall(OpenAIObject):
class HiddenParams(OpenAIObject):
    original_response: Optional[str] = None
    model_id: Optional[str] = None  # used in Router for individual deployments
+   api_base: Optional[str] = None  # returns api base used for making completion call

    class Config:
        extra = "allow"
@@ -3157,6 +3158,10 @@ def client(original_function):
            result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                "id", None
            )
+           result._hidden_params["api_base"] = get_api_base(
+               model=model,
+               optional_params=getattr(logging_obj, "optional_params", {}),
+           )
            result._response_ms = (
                end_time - start_time
            ).total_seconds() * 1000  # return response latency in ms like openai
@@ -3226,6 +3231,8 @@ def client(original_function):
        call_type = original_function.__name__
        if "litellm_call_id" not in kwargs:
            kwargs["litellm_call_id"] = str(uuid.uuid4())
+
+       model = ""
        try:
            model = args[0] if len(args) > 0 else kwargs["model"]
        except:
@@ -3547,6 +3554,10 @@ def client(original_function):
            result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                "id", None
            )
+           result._hidden_params["api_base"] = get_api_base(
+               model=model,
+               optional_params=kwargs,
+           )
            if (
                isinstance(result, ModelResponse)
                or isinstance(result, EmbeddingResponse)
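
On the SDK side, the effect of these wrapper changes is that the API base used for the call rides along on the response object. A minimal sketch, assuming a valid `OPENAI_API_KEY` in the environment and that `_hidden_params` is the plain dict populated above:

```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)

# _hidden_params is populated by the client wrapper shown above.
print(response._hidden_params.get("api_base"))  # e.g. "https://api.openai.com"
print(response._hidden_params.get("model_id"))  # set when routing through the Router
```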
@@ -5810,19 +5821,40 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
        get_api_base(model="gemini/gemini-pro")
        ```
    """
+   try:
+       if "model" in optional_params:
+           _optional_params = LiteLLM_Params(**optional_params)
+       else:  # prevent needing to copy and pop the dict
            _optional_params = LiteLLM_Params(
                model=model, **optional_params
            )  # convert to pydantic object
+   except Exception as e:
+       verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+       return None
    # get llm provider
-   try:
-       model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
-           model=model
-       )
-   except:
-       custom_llm_provider = None
    if _optional_params.api_base is not None:
        return _optional_params.api_base
+
+   try:
+       model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
+           get_llm_provider(
+               model=model,
+               custom_llm_provider=_optional_params.custom_llm_provider,
+               api_base=_optional_params.api_base,
+               api_key=_optional_params.api_key,
+           )
+       )
+   except Exception as e:
+       verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+       custom_llm_provider = None
+       dynamic_api_key = None
+       dynamic_api_base = None
+
+   if dynamic_api_base is not None:
+       return dynamic_api_base
+
    if (
        _optional_params.vertex_location is not None
        and _optional_params.vertex_project is not None
@@ -5835,11 +5867,17 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
        )
        return _api_base
-   if custom_llm_provider is not None and custom_llm_provider == "gemini":
+   if custom_llm_provider is None:
+       return None
+
+   if custom_llm_provider == "gemini":
        _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
            model
        )
        return _api_base
+   elif custom_llm_provider == "openai":
+       _api_base = "https://api.openai.com"
+       return _api_base

    return None
@@ -6147,7 +6185,6 @@ def get_llm_provider(
    try:
        dynamic_api_key = None
        # check if llm provider provided
-
        # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
        # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
        if model.split("/", 1)[0] == "azure":
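
Putting the `get_api_base` changes together, the resolution order is: an explicit `api_base` in `optional_params` wins, then a dynamic base returned by `get_llm_provider`, then the provider defaults (Vertex, Gemini, and the newly added OpenAI branch). A small sketch mirroring the new unit test above:

```python
from litellm.utils import get_api_base

# Explicit api_base in optional_params takes precedence.
print(get_api_base(model="openai/my-fake-model", optional_params={"api_base": "my-fake-api-base"}))
# -> "my-fake-api-base"

# Otherwise fall back to the provider default (new "openai" branch).
print(get_api_base(model="gpt-3.5-turbo", optional_params={}))
# -> "https://api.openai.com"
```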