Merge pull request #3430 from BerriAI/litellm_return_api_base

feat(proxy_server.py): return api base in response headers

Commit 1b35a75245 · 4 changed files with 67 additions and 11 deletions
@@ -3658,6 +3658,7 @@ async def chat_completion(
         hidden_params = getattr(response, "_hidden_params", {}) or {}
         model_id = hidden_params.get("model_id", None) or ""
         cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""

         # Post Call Processing
         if llm_router is not None:

@@ -3670,6 +3671,7 @@ async def chat_completion(
             custom_headers = {
                 "x-litellm-model-id": model_id,
                 "x-litellm-cache-key": cache_key,
+                "x-litellm-model-api-base": api_base,
             }
             selected_data_generator = select_data_generator(
                 response=response, user_api_key_dict=user_api_key_dict

@@ -3682,6 +3684,7 @@ async def chat_completion(

         fastapi_response.headers["x-litellm-model-id"] = model_id
         fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base

         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
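With this change the backend api base used for a request is returned to proxy callers as an `x-litellm-model-api-base` response header, alongside the existing `x-litellm-model-id` and `x-litellm-cache-key` headers, in both the streaming and non-streaming paths. A minimal client-side sketch of reading the new header, assuming a locally running proxy at http://0.0.0.0:4000 with a placeholder master key and model alias (none of these values come from the PR):

import requests

# Placeholder proxy URL, key, and model alias.
resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    },
)

print(resp.headers.get("x-litellm-model-id"))        # existing header
print(resp.headers.get("x-litellm-model-api-base"))  # new header from this PR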
@@ -15,6 +15,7 @@ import litellm
 import pytest
 import asyncio
 from unittest.mock import patch, MagicMock
+from litellm.utils import get_api_base
 from litellm.caching import DualCache
 from litellm.integrations.slack_alerting import SlackAlerting
 from litellm.proxy._types import UserAPIKeyAuth

@@ -74,6 +75,19 @@ async def test_slack_alerting_llm_exceptions(exception_type, monkeypatch):
     await asyncio.sleep(2)


+@pytest.mark.parametrize(
+    "model, optional_params, expected_api_base",
+    [
+        ("openai/my-fake-model", {"api_base": "my-fake-api-base"}, "my-fake-api-base"),
+        ("gpt-3.5-turbo", {}, "https://api.openai.com"),
+    ],
+)
+def test_get_api_base_unit_test(model, optional_params, expected_api_base):
+    api_base = get_api_base(model=model, optional_params=optional_params)
+
+    assert api_base == expected_api_base
+
+
 @pytest.mark.asyncio
 async def test_get_api_base():
     _pl = ProxyLogging(user_api_key_cache=DualCache())
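The new `test_get_api_base_unit_test` exercises two cases: an explicit `api_base` passed through `optional_params`, and the default OpenAI base. Because the rewritten `get_api_base` (see the utils hunks below) returns `_optional_params.api_base` before it ever looks up the provider, the table could be extended with further override cases; the last tuple below is a hypothetical addition for illustration, not part of the PR:

@pytest.mark.parametrize(
    "model, optional_params, expected_api_base",
    [
        ("openai/my-fake-model", {"api_base": "my-fake-api-base"}, "my-fake-api-base"),
        ("gpt-3.5-turbo", {}, "https://api.openai.com"),
        # hypothetical extra case: any model with an explicit api_base override
        ("my-custom-model", {"api_base": "https://example.com/v1"}, "https://example.com/v1"),
    ],
)
def test_get_api_base_unit_test(model, optional_params, expected_api_base):
    api_base = get_api_base(model=model, optional_params=optional_params)

    assert api_base == expected_api_base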
@@ -99,6 +99,7 @@ class ModelInfo(BaseModel):

 class LiteLLM_Params(BaseModel):
     model: str
+    custom_llm_provider: Optional[str] = None
     tpm: Optional[int] = None
     rpm: Optional[int] = None
     api_key: Optional[str] = None

@@ -123,6 +124,7 @@ class LiteLLM_Params(BaseModel):
     def __init__(
         self,
         model: str,
+        custom_llm_provider: Optional[str] = None,
         max_retries: Optional[Union[int, str]] = None,
         tpm: Optional[int] = None,
         rpm: Optional[int] = None,
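`custom_llm_provider` is added to `LiteLLM_Params` so that a caller-supplied provider hint survives the conversion of `optional_params` into the pydantic object and can be forwarded to `get_llm_provider`. A small sketch of constructing the object directly, assuming `LiteLLM_Params` is importable from `litellm.types.router` (the file path is not shown in this view) and using illustrative values:

from litellm.types.router import LiteLLM_Params  # assumed import path

# Illustrative values; the new field defaults to None when not supplied.
params = LiteLLM_Params(
    model="command-r-plus",
    custom_llm_provider="cohere_chat",
)
print(params.custom_llm_provider)  # "cohere_chat"
print(params.api_base)             # None unless explicitly passed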
@@ -315,6 +315,7 @@ class ChatCompletionDeltaToolCall(OpenAIObject):
 class HiddenParams(OpenAIObject):
     original_response: Optional[str] = None
     model_id: Optional[str] = None  # used in Router for individual deployments
+    api_base: Optional[str] = None  # returns api base used for making completion call

     class Config:
         extra = "allow"

@@ -3157,6 +3158,10 @@ def client(original_function):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                     "id", None
                 )
+                result._hidden_params["api_base"] = get_api_base(
+                    model=model,
+                    optional_params=getattr(logging_obj, "optional_params", {}),
+                )
                 result._response_ms = (
                     end_time - start_time
                 ).total_seconds() * 1000  # return response latency in ms like openai

@@ -3226,6 +3231,8 @@ def client(original_function):
         call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
+
+        model = ""
         try:
             model = args[0] if len(args) > 0 else kwargs["model"]
         except:

@@ -3547,6 +3554,10 @@ def client(original_function):
                 result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
                     "id", None
                 )
+                result._hidden_params["api_base"] = get_api_base(
+                    model=model,
+                    optional_params=kwargs,
+                )
             if (
                 isinstance(result, ModelResponse)
                 or isinstance(result, EmbeddingResponse)
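In what appear to be the sync and async result paths of the `client` wrapper, `api_base` is now stamped onto `result._hidden_params` next to `model_id`; this is the value the proxy later copies into the response headers. A sketch of reading it straight from the Python SDK, mirroring the defensive `getattr(..., "_hidden_params", {})` pattern from the proxy hunks above (model is a placeholder and a valid OPENAI_API_KEY is assumed):

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)

# Same pattern as the proxy's post-call processing.
hidden_params = getattr(response, "_hidden_params", {}) or {}
print(hidden_params.get("api_base", None))  # e.g. "https://api.openai.com"
print(hidden_params.get("model_id", None))  # set when the call goes through the Router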
@@ -5810,19 +5821,40 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
     get_api_base(model="gemini/gemini-pro")
     ```
     """

-    _optional_params = LiteLLM_Params(
-        model=model, **optional_params
-    )  # convert to pydantic object
-    # get llm provider
-    try:
-        model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
-            model=model
-        )
-    except:
-        custom_llm_provider = None
+    try:
+        if "model" in optional_params:
+            _optional_params = LiteLLM_Params(**optional_params)
+        else:  # prevent needing to copy and pop the dict
+            _optional_params = LiteLLM_Params(
+                model=model, **optional_params
+            )  # convert to pydantic object
+    except Exception as e:
+        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        return None
+    # get llm provider
     if _optional_params.api_base is not None:
         return _optional_params.api_base

+    try:
+        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
+            get_llm_provider(
+                model=model,
+                custom_llm_provider=_optional_params.custom_llm_provider,
+                api_base=_optional_params.api_base,
+                api_key=_optional_params.api_key,
+            )
+        )
+    except Exception as e:
+        verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
+        custom_llm_provider = None
+        dynamic_api_key = None
+        dynamic_api_base = None
+
+    if dynamic_api_base is not None:
+        return dynamic_api_base
+
     if (
         _optional_params.vertex_location is not None
         and _optional_params.vertex_project is not None

@@ -5835,11 +5867,17 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
         )
         return _api_base

-    if custom_llm_provider is not None and custom_llm_provider == "gemini":
+    if custom_llm_provider is None:
+        return None
+
+    if custom_llm_provider == "gemini":
         _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
             model
         )
         return _api_base
+    elif custom_llm_provider == "openai":
+        _api_base = "https://api.openai.com"
+        return _api_base
     return None
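Taken together, the rewritten `get_api_base` resolves the base in a fixed order: an explicit `api_base` in `optional_params`, then a dynamic base returned by `get_llm_provider`, then the Vertex AI project/location case, then the hard-coded Gemini and OpenAI defaults, and finally `None` when the provider cannot be determined. A few illustrative calls; the expected values follow from the code and the new unit test rather than from separate verification:

from litellm.utils import get_api_base

# Explicit override wins before any provider lookup.
get_api_base(model="openai/my-fake-model", optional_params={"api_base": "my-fake-api-base"})
# -> "my-fake-api-base"

# Default OpenAI base from the new elif branch.
get_api_base(model="gpt-3.5-turbo", optional_params={})
# -> "https://api.openai.com"

# Gemini models resolve to the generateContent endpoint.
get_api_base(model="gemini/gemini-pro", optional_params={})
# -> "https://generativelanguage.googleapis.com/v1beta/models/<model>:generateContent"

# Unknown provider: nothing to return.
get_api_base(model="some-unknown-model", optional_params={})
# -> None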
@@ -6147,7 +6185,6 @@ def get_llm_provider(
     try:
         dynamic_api_key = None
         # check if llm provider provided
-
         # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
         # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
         if model.split("/", 1)[0] == "azure":