Merge pull request #3430 from BerriAI/litellm_return_api_base
feat(proxy_server.py): return api base in response headers
Commit: 1b35a75245
4 changed files with 67 additions and 11 deletions
@@ -3658,6 +3658,7 @@ async def chat_completion(
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""

         # Post Call Processing
         if llm_router is not None:
@@ -3670,6 +3671,7 @@ async def chat_completion(
             custom_headers = {
                 "x-litellm-model-id": model_id,
                 "x-litellm-cache-key": cache_key,
+                "x-litellm-model-api-base": api_base,
             }
             selected_data_generator = select_data_generator(
                 response=response, user_api_key_dict=user_api_key_dict
@@ -3682,6 +3684,7 @@ async def chat_completion(

         fastapi_response.headers["x-litellm-model-id"] = model_id
         fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base

         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
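With the three chat_completion hunks above, a non-streaming /chat/completions call through the proxy now returns the upstream API base in a response header, next to the existing x-litellm-model-id and x-litellm-cache-key headers; the streaming path sends the same three values via custom_headers. A minimal client-side sketch, assuming a locally running LiteLLM proxy on port 4000 and a virtual key "sk-1234" (both placeholders, not values from this commit):

```python
import requests

# Placeholders: adjust the proxy URL and key to your deployment.
PROXY_URL = "http://0.0.0.0:4000/chat/completions"
VIRTUAL_KEY = "sk-1234"

resp = requests.post(
    PROXY_URL,
    headers={"Authorization": f"Bearer {VIRTUAL_KEY}"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
    },
)

# Headers set by the hunks above; x-litellm-model-api-base is the new one.
print(resp.headers.get("x-litellm-model-id"))
print(resp.headers.get("x-litellm-cache-key"))
print(resp.headers.get("x-litellm-model-api-base"))
```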
@@ -15,6 +15,7 @@ import litellm
 import pytest
 import asyncio
 from unittest.mock import patch, MagicMock
+from litellm.utils import get_api_base
 from litellm.caching import DualCache
 from litellm.integrations.slack_alerting import SlackAlerting
 from litellm.proxy._types import UserAPIKeyAuth
@@ -74,6 +75,19 @@ async def test_slack_alerting_llm_exceptions(exception_type, monkeypatch):
     await asyncio.sleep(2)


+@pytest.mark.parametrize(
+    "model, optional_params, expected_api_base",
+    [
+        ("openai/my-fake-model", {"api_base": "my-fake-api-base"}, "my-fake-api-base"),
+        ("gpt-3.5-turbo", {}, "https://api.openai.com"),
+    ],
+)
+def test_get_api_base_unit_test(model, optional_params, expected_api_base):
+    api_base = get_api_base(model=model, optional_params=optional_params)
+
+    assert api_base == expected_api_base
+
+
 @pytest.mark.asyncio
 async def test_get_api_base():
     _pl = ProxyLogging(user_api_key_cache=DualCache())
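The new parametrized test pins down the two simplest cases. Outside pytest, the same checks read as below; "openai/my-fake-model" and "my-fake-api-base" are fake values taken from the test, not real models or endpoints.

```python
from litellm.utils import get_api_base

# An explicitly passed api_base is returned as-is.
assert (
    get_api_base(
        model="openai/my-fake-model",
        optional_params={"api_base": "my-fake-api-base"},
    )
    == "my-fake-api-base"
)

# With no api_base supplied, OpenAI models fall back to the provider default.
assert get_api_base(model="gpt-3.5-turbo", optional_params={}) == "https://api.openai.com"
```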
@@ -99,6 +99,7 @@ class ModelInfo(BaseModel):

 class LiteLLM_Params(BaseModel):
     model: str
+    custom_llm_provider: Optional[str] = None
     tpm: Optional[int] = None
     rpm: Optional[int] = None
     api_key: Optional[str] = None
@@ -123,6 +124,7 @@ class LiteLLM_Params(BaseModel):
     def __init__(
         self,
         model: str,
+        custom_llm_provider: Optional[str] = None,
         max_retries: Optional[Union[int, str]] = None,
         tpm: Optional[int] = None,
         rpm: Optional[int] = None,
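These two hunks thread custom_llm_provider through LiteLLM_Params, which is what lets the reworked get_api_base (further down) forward it to get_llm_provider. A stand-in sketch limited to the fields visible in this diff; the real class defines more fields (api_base, vertex_project, vertex_location, ...) that get_api_base reads.

```python
from typing import Optional
from pydantic import BaseModel


# Stand-in for illustration only; field list trimmed to what the hunks above show.
class LiteLLM_Params(BaseModel):
    model: str
    custom_llm_provider: Optional[str] = None
    tpm: Optional[int] = None
    rpm: Optional[int] = None
    api_key: Optional[str] = None


params = LiteLLM_Params(model="azure/command-r-plus", custom_llm_provider="azure", rpm=100)
print(params.custom_llm_provider)  # "azure"
```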
@@ -315,6 +315,7 @@ class ChatCompletionDeltaToolCall(OpenAIObject):
 class HiddenParams(OpenAIObject):
     original_response: Optional[str] = None
     model_id: Optional[str] = None  # used in Router for individual deployments
+    api_base: Optional[str] = None  # returns api base used for making completion call

     class Config:
         extra = "allow"
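HiddenParams gains an api_base field, and the client wrapper below writes the same key into every response's _hidden_params. A quick sketch of reading it from the SDK side; this makes a real call, so it assumes OPENAI_API_KEY is set in the environment.

```python
import litellm

# Assumes OPENAI_API_KEY is exported; model and prompt are arbitrary.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)

hidden_params = getattr(response, "_hidden_params", {}) or {}
print(hidden_params.get("model_id"))  # populated when the call goes through the Router
print(hidden_params.get("api_base"))  # "https://api.openai.com" per the unit test above
```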
@@ -3157,6 +3158,10 @@ def client(original_function):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                     "id", None
                 )
+                result._hidden_params["api_base"] = get_api_base(
+                    model=model,
+                    optional_params=getattr(logging_obj, "optional_params", {}),
+                )
                 result._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai
@@ -3226,6 +3231,8 @@ def client(original_function):
         call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
+
+        model = ""
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:
@@ -3547,6 +3554,10 @@ def client(original_function):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                     "id", None
                 )
+                result._hidden_params["api_base"] = get_api_base(
+                    model=model,
+                    optional_params=kwargs,
+                )
                 if (
                     isinstance(result, ModelResponse)
                     or isinstance(result, EmbeddingResponse)
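In this synchronous path the raw call kwargs are what get_api_base receives, so a per-call api_base override is what ends up in _hidden_params. A sketch of that resolution using get_api_base directly, with a trimmed-down stand-in for the kwargs (the localhost URL and key are illustrative, not from this commit):

```python
from litellm.utils import get_api_base

# Trimmed stand-in for the kwargs the sync wrapper forwards.
call_kwargs = {
    "model": "openai/my-fake-model",
    "api_base": "http://localhost:8080/v1",
    "api_key": "sk-fake",
}

print(get_api_base(model=call_kwargs["model"], optional_params=call_kwargs))
# -> "http://localhost:8080/v1"
```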
@@ -5810,19 +5821,40 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
         get_api_base(model="gemini/gemini-pro")
         ```
     """
-    _optional_params = LiteLLM_Params(
-        model=model, **optional_params
-    )  # convert to pydantic object
-    # get llm provider
-
     try:
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
-            model=model
-        )
-    except:
-        custom_llm_provider = None
+        if "model" in optional_params:
+            _optional_params = LiteLLM_Params(**optional_params)
+        else:  # prevent needing to copy and pop the dict
+            _optional_params = LiteLLM_Params(
+                model=model, **optional_params
+            )  # convert to pydantic object
+    except Exception as e:
+        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        return None
+    # get llm provider
+
+    if _optional_params.api_base is not None:
+        return _optional_params.api_base
+
+    try:
+        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
+            get_llm_provider(
+                model=model,
+                custom_llm_provider=_optional_params.custom_llm_provider,
+                api_base=_optional_params.api_base,
+                api_key=_optional_params.api_key,
+            )
+        )
+    except Exception as e:
+        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        custom_llm_provider = None
+        dynamic_api_key = None
+        dynamic_api_base = None
+
+    if dynamic_api_base is not None:
+        return dynamic_api_base

     if (
         _optional_params.vertex_location is not None
         and _optional_params.vertex_project is not None
@@ -5835,11 +5867,17 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
         )
         return _api_base

-    if custom_llm_provider is not None and custom_llm_provider == "gemini":
+    if custom_llm_provider is None:
+        return None
+
+    if custom_llm_provider == "gemini":
         _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
             model
         )
         return _api_base
+    elif custom_llm_provider == "openai":
+        _api_base = "https://api.openai.com"
+        return _api_base
     return None
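Beyond the Vertex and Gemini branches, the hunk above adds an explicit early return when no provider can be resolved and a default base for OpenAI. A short sketch of those fallback branches; "some-unknown-model" is a made-up name used only to trigger the None path.

```python
from litellm.utils import get_api_base

# Known provider, no api_base supplied anywhere: the provider default applies.
print(get_api_base(model="gpt-3.5-turbo", optional_params={}))
# -> "https://api.openai.com"

# Provider cannot be resolved: the early return yields None (an error is logged).
print(get_api_base(model="some-unknown-model", optional_params={}))
# -> None
```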
@@ -6147,7 +6185,6 @@ def get_llm_provider(
     try:
         dynamic_api_key = None
         # check if llm provider provided
-
         # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
         # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
         if model.split("/", 1)[0] == "azure":
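get_llm_provider is the helper get_api_base leans on: it splits "provider/model" strings and reports the provider plus any dynamically known key or base. A small sketch of the tuple layout, matching the unpacking used in get_api_base above (the key and base may be None depending on the environment):

```python
from litellm.utils import get_llm_provider

model, custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
    model="gemini/gemini-pro"
)
print(model)                # "gemini-pro" (provider prefix stripped)
print(custom_llm_provider)  # "gemini"
print(dynamic_api_key, dynamic_api_base)  # may be None unless set via env/config
```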