feat(proxy/utils.py): return api base for request hanging alerts

Krrish Dholakia 2024-04-06 15:58:53 -07:00
parent b49e47b634
commit 6110d32b1c
7 changed files with 180 additions and 15 deletions


@@ -75,6 +75,7 @@ from .proxy._types import KeyManagementSystem
from openai import OpenAIError as OriginalError
from openai._models import BaseModel as OpenAIObject
from .caching import S3Cache, RedisSemanticCache, RedisCache
from .router import LiteLLM_Params
from .exceptions import (
    AuthenticationError,
    BadRequestError,
@@ -1075,6 +1076,9 @@ class Logging:
        headers = {}
        data = additional_args.get("complete_input_dict", {})
        api_base = additional_args.get("api_base", "")
        self.model_call_details["litellm_params"]["api_base"] = str(
            api_base
        )  # used for alerting
        masked_headers = {
            k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
            for k, v in headers.items()
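
The masking comprehension above can be sanity-checked standalone; a minimal sketch with made-up header values (not from this commit):

# Values longer than 20 chars keep their prefix and get the last 20 chars starred out;
# shorter values pass through untouched.
headers = {
    "Authorization": "Bearer sk-1234567890abcdefghijklmnop",  # 36 chars -> masked
    "x-request-id": "abc123",  # 6 chars -> left as-is
}
masked_headers = {
    k: (v[:-20] + "*" * 20) if (isinstance(v, str) and len(v) > 20) else v
    for k, v in headers.items()
}
print(masked_headers["Authorization"])  # Bearer sk-123456********************
print(masked_headers["x-request-id"])  # abc123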
@@ -1203,7 +1207,6 @@ class Logging:
self.model_call_details["original_response"] = original_response
self.model_call_details["additional_args"] = additional_args
self.model_call_details["log_event_type"] = "post_api_call"
# User Logging -> if you pass in a custom logging function
print_verbose(
f"RAW RESPONSE:\n{self.model_call_details.get('original_response', self.model_call_details)}\n\n",
@@ -2546,7 +2549,7 @@ def client(original_function):
                langfuse_secret=kwargs.pop("langfuse_secret", None),
            )
            ## check if metadata is passed in
            litellm_params = {}
            litellm_params = {"api_base": ""}
            if "metadata" in kwargs:
                litellm_params["metadata"] = kwargs["metadata"]
            logging_obj.update_environment_variables(
@@ -3033,7 +3036,7 @@ def client(original_function):
                cached_result = await litellm.cache.async_get_cache(
                    *args, **kwargs
                )
            else:
            else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
                kwargs["preset_cache_key"] = (
                    preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
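
The preset-key round trip can be illustrated in isolation; a rough sketch assuming litellm's public Cache API (the sample kwargs are made up):

import litellm

litellm.cache = litellm.Cache()  # in-memory cache by default

kwargs = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}
preset_cache_key = litellm.cache.get_cache_key(**kwargs)
# stashing the key on kwargs lets streamed chunks be written back under the
# same cache key that was computed before the call was dispatched
kwargs["preset_cache_key"] = preset_cache_key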
@@ -3076,6 +3079,7 @@ def client(original_function):
"preset_cache_key", None
),
"stream_response": kwargs.get("stream_response", {}),
"api_base": kwargs.get("api_base", ""),
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
@@ -3209,6 +3213,7 @@ def client(original_function):
"stream_response": kwargs.get(
"stream_response", {}
),
"api_base": "",
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
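
With api_base now threaded into litellm_params, an alerting path can name the stuck endpoint; a hypothetical consumer (the function name and dict shape are assumptions, not part of this commit):

def format_hanging_request_alert(model_call_details: dict) -> str:
    # hypothetical helper: reads the api_base that Logging.pre_call stores above
    litellm_params = model_call_details.get("litellm_params", {})
    api_base = litellm_params.get("api_base", "")
    model = model_call_details.get("model", "unknown")
    return f"Request hanging for model={model}, api_base={api_base or 'n/a'}"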
@@ -5305,6 +5310,27 @@ def get_optional_params(
    return optional_params


def get_api_base(model: str, optional_params: dict) -> Optional[str]:
    """Returns the api base for a request, if one can be determined from the params."""
    _optional_params = LiteLLM_Params(**optional_params)  # convert to pydantic object
    if _optional_params.api_base is not None:
        return _optional_params.api_base
    if (
        _optional_params.vertex_location is not None
        and _optional_params.vertex_project is not None
    ):
        _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
            _optional_params.vertex_location,
            _optional_params.vertex_project,
            _optional_params.vertex_location,
            model,
        )
        return _api_base
    return None
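
How the new helper behaves, sketched with made-up inputs (the import path is an assumption, and LiteLLM_Params must tolerate the keys you pass):

from litellm.utils import get_api_base  # assumed import path

# An explicit api_base in the params wins:
print(
    get_api_base(
        model="gpt-3.5-turbo",
        optional_params={"api_base": "https://my-proxy.example.com"},
    )
)  # https://my-proxy.example.com

# With vertex_location and vertex_project set and no api_base, the Google
# endpoint is assembled from the format string above:
print(
    get_api_base(
        model="gemini-pro",
        optional_params={"vertex_location": "us-central1", "vertex_project": "my-proj"},
    )
)
# us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/publishers/google/models/gemini-pro:streamGenerateContent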
def get_supported_openai_params(model: str, custom_llm_provider: str):
    """
    Returns the supported openai params for a given model + provider