feat(proxy_server.py): return api base in response headers

Closes https://github.com/BerriAI/litellm/issues/2631
This commit is contained in:
Krrish Dholakia 2024-05-03 15:27:32 -07:00
parent b2a0502383
commit 5b39f8e282
4 changed files with 67 additions and 11 deletions

View file

@ -315,6 +315,7 @@ class ChatCompletionDeltaToolCall(OpenAIObject):
class HiddenParams(OpenAIObject):
original_response: Optional[str] = None
model_id: Optional[str] = None # used in Router for individual deployments
api_base: Optional[str] = None # returns api base used for making completion call
class Config:
extra = "allow"
@ -3157,6 +3158,10 @@ def client(original_function):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
result._hidden_params["api_base"] = get_api_base(
model=model,
optional_params=getattr(logging_obj, "optional_params", {}),
)
result._response_ms = (
end_time - start_time
).total_seconds() * 1000 # return response latency in ms like openai
@ -3226,6 +3231,8 @@ def client(original_function):
call_type = original_function.__name__
if "litellm_call_id" not in kwargs:
kwargs["litellm_call_id"] = str(uuid.uuid4())
model = ""
try:
model = args[0] if len(args) > 0 else kwargs["model"]
except:
@ -3547,6 +3554,10 @@ def client(original_function):
result._hidden_params["model_id"] = kwargs.get("model_info", {}).get(
"id", None
)
result._hidden_params["api_base"] = get_api_base(
model=model,
optional_params=kwargs,
)
if (
isinstance(result, ModelResponse)
or isinstance(result, EmbeddingResponse)
@ -5810,19 +5821,40 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
get_api_base(model="gemini/gemini-pro")
```
"""
_optional_params = LiteLLM_Params(
model=model, **optional_params
) # convert to pydantic object
# get llm provider
try:
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
model=model
)
except:
custom_llm_provider = None
if "model" in optional_params:
_optional_params = LiteLLM_Params(**optional_params)
else: # prevent needing to copy and pop the dict
_optional_params = LiteLLM_Params(
model=model, **optional_params
) # convert to pydantic object
except Exception as e:
verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
return None
# get llm provider
if _optional_params.api_base is not None:
return _optional_params.api_base
try:
model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
get_llm_provider(
model=model,
custom_llm_provider=_optional_params.custom_llm_provider,
api_base=_optional_params.api_base,
api_key=_optional_params.api_key,
)
)
except Exception as e:
verbose_logger.error("Error occurred in getting api base - {}".format(str(e)))
custom_llm_provider = None
dynamic_api_key = None
dynamic_api_base = None
if dynamic_api_base is not None:
return dynamic_api_base
if (
_optional_params.vertex_location is not None
and _optional_params.vertex_project is not None
@ -5835,11 +5867,17 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
)
return _api_base
if custom_llm_provider is not None and custom_llm_provider == "gemini":
if custom_llm_provider is None:
return None
if custom_llm_provider == "gemini":
_api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
model
)
return _api_base
elif custom_llm_provider == "openai":
_api_base = "https://api.openai.com"
return _api_base
return None
@ -6147,7 +6185,6 @@ def get_llm_provider(
try:
dynamic_api_key = None
# check if llm provider provided
# AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere
# If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus
if model.split("/", 1)[0] == "azure":