mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
refactor location of proxy
This commit is contained in:
parent baa5564f95, commit ce58c53ff1
413 changed files with 2087 additions and 2088 deletions
@@ -1,918 +0,0 @@
import asyncio
import copy
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from fastapi import Request
from starlette.datastructures import Headers

import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.proxy._types import (
    AddTeamCallback,
    CommonProxyErrors,
    LitellmDataForBackendLLMCall,
    SpecialHeaders,
    TeamCallbackMetadata,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.router import Router
from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
from litellm.types.services import ServiceTypes
from litellm.types.utils import (
    ProviderSpecificHeader,
    StandardLoggingUserAPIKeyMetadata,
    SupportedCacheControls,
)

service_logger_obj = ServiceLogging()  # used for tracking latency on OTEL


if TYPE_CHECKING:
    from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig

    ProxyConfig = _ProxyConfig
else:
    ProxyConfig = Any

def parse_cache_control(cache_control):
    cache_dict = {}
    directives = cache_control.split(", ")

    for directive in directives:
        if "=" in directive:
            key, value = directive.split("=")
            cache_dict[key] = value
        else:
            cache_dict[directive] = True

    return cache_dict
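
# Illustrative usage sketch (not part of the original module): shows the dict
# shape parse_cache_control produces; add_litellm_data_to_request later reads
# the "s-maxage" directive from this dict to set the request TTL.
def _example_parse_cache_control() -> None:
    parsed = parse_cache_control("s-maxage=604800, no-store")
    assert parsed == {"s-maxage": "604800", "no-store": True}
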
def _get_metadata_variable_name(request: Request) -> str:
    """
    Helper to return what the "metadata" field should be called in the request data

    For all /thread or /assistant endpoints we need to call this "litellm_metadata"

    For ALL other endpoints we call this "metadata"
    """
    if RouteChecks._is_assistants_api_request(request):
        return "litellm_metadata"

    LITELLM_METADATA_ROUTES = [
        "batches",
        "/v1/messages",
        "responses",
    ]
    if any(
        [
            litellm_metadata_route in request.url.path
            for litellm_metadata_route in LITELLM_METADATA_ROUTES
        ]
    ):
        return "litellm_metadata"
    else:
        return "metadata"

def safe_add_api_version_from_query_params(data: dict, request: Request):
    try:
        if hasattr(request, "query_params"):
            query_params = dict(request.query_params)
            if "api-version" in query_params:
                data["api_version"] = query_params["api-version"]
    except KeyError:
        pass
    except Exception as e:
        verbose_logger.exception(
            "error checking api version in query params: %s", str(e)
        )


def convert_key_logging_metadata_to_callback(
    data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
) -> TeamCallbackMetadata:
    if team_callback_settings_obj is None:
        team_callback_settings_obj = TeamCallbackMetadata()
    if data.callback_type == "success":
        if team_callback_settings_obj.success_callback is None:
            team_callback_settings_obj.success_callback = []

        if data.callback_name not in team_callback_settings_obj.success_callback:
            team_callback_settings_obj.success_callback.append(data.callback_name)
    elif data.callback_type == "failure":
        if team_callback_settings_obj.failure_callback is None:
            team_callback_settings_obj.failure_callback = []

        if data.callback_name not in team_callback_settings_obj.failure_callback:
            team_callback_settings_obj.failure_callback.append(data.callback_name)
    elif (
        not data.callback_type or data.callback_type == "success_and_failure"
    ):  # assume 'success_and_failure' = litellm.callbacks
        if team_callback_settings_obj.success_callback is None:
            team_callback_settings_obj.success_callback = []
        if team_callback_settings_obj.failure_callback is None:
            team_callback_settings_obj.failure_callback = []
        if team_callback_settings_obj.callbacks is None:
            team_callback_settings_obj.callbacks = []

        if data.callback_name not in team_callback_settings_obj.success_callback:
            team_callback_settings_obj.success_callback.append(data.callback_name)

        if data.callback_name not in team_callback_settings_obj.failure_callback:
            team_callback_settings_obj.failure_callback.append(data.callback_name)

        if data.callback_name not in team_callback_settings_obj.callbacks:
            team_callback_settings_obj.callbacks.append(data.callback_name)

    for var, value in data.callback_vars.items():
        if team_callback_settings_obj.callback_vars is None:
            team_callback_settings_obj.callback_vars = {}
        team_callback_settings_obj.callback_vars[var] = str(
            litellm.utils.get_secret(value, default_value=value) or value
        )

    return team_callback_settings_obj
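
# Illustrative usage sketch (not part of the original module). Assumes
# AddTeamCallback accepts the callback_name / callback_type / callback_vars
# fields that are read above.
def _example_convert_key_logging_metadata_to_callback() -> None:
    settings = convert_key_logging_metadata_to_callback(
        data=AddTeamCallback(
            callback_name="langfuse",
            callback_type="success",
            callback_vars={},
        ),
        team_callback_settings_obj=None,
    )
    assert settings.success_callback == ["langfuse"]
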
def _get_dynamic_logging_metadata(
    user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
) -> Optional[TeamCallbackMetadata]:
    callback_settings_obj: Optional[TeamCallbackMetadata] = None
    if (
        user_api_key_dict.metadata is not None
        and "logging" in user_api_key_dict.metadata
    ):
        for item in user_api_key_dict.metadata["logging"]:
            callback_settings_obj = convert_key_logging_metadata_to_callback(
                data=AddTeamCallback(**item),
                team_callback_settings_obj=callback_settings_obj,
            )
    elif (
        user_api_key_dict.team_metadata is not None
        and "callback_settings" in user_api_key_dict.team_metadata
    ):
        """
        callback_settings = {
            'callback_vars': {'langfuse_public_key': 'pk', 'langfuse_secret_key': 'sk_'},
            'failure_callback': [],
            'success_callback': ['langfuse', 'langfuse']
        }
        """
        team_metadata = user_api_key_dict.team_metadata
        callback_settings = team_metadata.get("callback_settings", None) or {}
        callback_settings_obj = TeamCallbackMetadata(**callback_settings)
        verbose_proxy_logger.debug(
            "Team callback settings activated: %s", callback_settings_obj
        )
    elif user_api_key_dict.team_id is not None:
        callback_settings_obj = (
            LiteLLMProxyRequestSetup.add_team_based_callbacks_from_config(
                team_id=user_api_key_dict.team_id, proxy_config=proxy_config
            )
        )
    return callback_settings_obj

def clean_headers(
    headers: Headers, litellm_key_header_name: Optional[str] = None
) -> dict:
    """
    Removes the litellm api key from the headers.
    """
    special_headers = [v.value.lower() for v in SpecialHeaders._member_map_.values()]
    if litellm_key_header_name is not None:
        special_headers.append(litellm_key_header_name.lower())
    clean_headers = {}

    for header, value in headers.items():
        if header.lower() not in special_headers:
            clean_headers[header] = value
    return clean_headers
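
# Illustrative usage sketch (not part of the original module). The header name
# "x-custom-litellm-key" is hypothetical and stands in for a configured
# litellm_key_header_name.
def _example_clean_headers() -> None:
    cleaned = clean_headers(
        Headers({"x-custom-litellm-key": "sk-1234", "x-request-id": "abc"}),
        litellm_key_header_name="x-custom-litellm-key",
    )
    assert "x-custom-litellm-key" not in cleaned
    assert cleaned["x-request-id"] == "abc"
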
class LiteLLMProxyRequestSetup:
    @staticmethod
    def _get_timeout_from_request(headers: dict) -> Optional[float]:
        """
        Workaround for client requests from Vercel's AI SDK.

        Allows the user to set a timeout in the request headers.

        Example:

        ```js
        const openaiProvider = createOpenAI({
            baseURL: liteLLM.baseURL,
            apiKey: liteLLM.apiKey,
            compatibility: "compatible",
            headers: {
                "x-litellm-timeout": "90"
            },
        });
        ```
        """
        timeout_header = headers.get("x-litellm-timeout", None)
        if timeout_header is not None:
            return float(timeout_header)
        return None
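
    # Illustrative usage sketch (not part of the original class): the header
    # value arrives as a string and is returned as a float.
    @staticmethod
    def _example_get_timeout_from_request() -> None:
        timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(
            {"x-litellm-timeout": "90"}
        )
        assert timeout == 90.0
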
    @staticmethod
    def _get_forwardable_headers(
        headers: Union[Headers, dict],
    ):
        """
        Get the headers that should be forwarded to the LLM Provider.

        Looks for any `x-` headers and sends them to the LLM Provider.
        """
        forwarded_headers = {}
        for header, value in headers.items():
            if header.lower().startswith("x-") and not header.lower().startswith(
                "x-stainless"
            ):  # causes openai sdk to fail
                forwarded_headers[header] = value

        return forwarded_headers
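
    # Illustrative usage sketch (not part of the original class): only `x-`
    # headers are forwarded, and `x-stainless-*` headers set by the OpenAI SDK
    # are dropped.
    @staticmethod
    def _example_get_forwardable_headers() -> None:
        forwarded = LiteLLMProxyRequestSetup._get_forwardable_headers(
            {
                "x-correlation-id": "abc",
                "x-stainless-os": "Linux",
                "content-type": "application/json",
            }
        )
        assert forwarded == {"x-correlation-id": "abc"}
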
    @staticmethod
    def get_openai_org_id_from_headers(
        headers: dict, general_settings: Optional[Dict] = None
    ) -> Optional[str]:
        """
        Get the OpenAI Org ID from the headers.
        """
        if (
            general_settings is not None
            and general_settings.get("forward_openai_org_id") is not True
        ):
            return None
        for header, value in headers.items():
            if header.lower() == "openai-organization":
                verbose_logger.info(f"found openai org id: {value}, sending to llm")
                return value
        return None
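
    # Illustrative usage sketch (not part of the original class): the org id is
    # only returned when `forward_openai_org_id` is enabled in general_settings.
    @staticmethod
    def _example_get_openai_org_id_from_headers() -> None:
        headers = {"OpenAI-Organization": "org-123"}
        assert (
            LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(headers, {})
            is None
        )
        assert (
            LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(
                headers, {"forward_openai_org_id": True}
            )
            == "org-123"
        )
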
    @staticmethod
    def add_headers_to_llm_call(
        headers: dict, user_api_key_dict: UserAPIKeyAuth
    ) -> dict:
        """
        Add headers to the LLM call

        - Checks request headers for forwardable headers
        - Checks if user information should be added to the headers
        """

        returned_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(headers)

        if litellm.add_user_information_to_llm_headers is True:
            litellm_logging_metadata_headers = (
                LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
                    user_api_key_dict=user_api_key_dict
                )
            )
            for k, v in litellm_logging_metadata_headers.items():
                if v is not None:
                    returned_headers["x-litellm-{}".format(k)] = v

        return returned_headers

    @staticmethod
    def add_litellm_data_for_backend_llm_call(
        *,
        headers: dict,
        user_api_key_dict: UserAPIKeyAuth,
        general_settings: Optional[Dict[str, Any]] = None,
    ) -> LitellmDataForBackendLLMCall:
        """
        - Adds forwardable headers
        - Adds org id
        """
        data = LitellmDataForBackendLLMCall()
        if (
            general_settings
            and general_settings.get("forward_client_headers_to_llm_api") is True
        ):
            _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
                headers, user_api_key_dict
            )
            if _headers != {}:
                data["headers"] = _headers
            _organization = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(
                headers, general_settings
            )
            if _organization is not None:
                data["organization"] = _organization

        timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
        if timeout is not None:
            data["timeout"] = timeout

        return data
    @staticmethod
    def get_sanitized_user_information_from_key(
        user_api_key_dict: UserAPIKeyAuth,
    ) -> StandardLoggingUserAPIKeyMetadata:
        user_api_key_logged_metadata = StandardLoggingUserAPIKeyMetadata(
            user_api_key_hash=user_api_key_dict.api_key,  # just the hashed token
            user_api_key_alias=user_api_key_dict.key_alias,
            user_api_key_team_id=user_api_key_dict.team_id,
            user_api_key_user_id=user_api_key_dict.user_id,
            user_api_key_org_id=user_api_key_dict.org_id,
            user_api_key_team_alias=user_api_key_dict.team_alias,
            user_api_key_end_user_id=user_api_key_dict.end_user_id,
            user_api_key_user_email=user_api_key_dict.user_email,
        )
        return user_api_key_logged_metadata

    @staticmethod
    def add_key_level_controls(
        key_metadata: dict, data: dict, _metadata_variable_name: str
    ):
        if "cache" in key_metadata:
            data["cache"] = {}
            if isinstance(key_metadata["cache"], dict):
                for k, v in key_metadata["cache"].items():
                    if k in SupportedCacheControls:
                        data["cache"][k] = v

        ## KEY-LEVEL SPEND LOGS / TAGS
        if "tags" in key_metadata and key_metadata["tags"] is not None:
            data[_metadata_variable_name][
                "tags"
            ] = LiteLLMProxyRequestSetup._merge_tags(
                request_tags=data[_metadata_variable_name].get("tags"),
                tags_to_add=key_metadata["tags"],
            )
        if "spend_logs_metadata" in key_metadata and isinstance(
            key_metadata["spend_logs_metadata"], dict
        ):
            if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
                data[_metadata_variable_name]["spend_logs_metadata"], dict
            ):
                for key, value in key_metadata["spend_logs_metadata"].items():
                    if (
                        key not in data[_metadata_variable_name]["spend_logs_metadata"]
                    ):  # don't override k-v pair sent by request (user request)
                        data[_metadata_variable_name]["spend_logs_metadata"][
                            key
                        ] = value
            else:
                data[_metadata_variable_name]["spend_logs_metadata"] = key_metadata[
                    "spend_logs_metadata"
                ]

        ## KEY-LEVEL DISABLE FALLBACKS
        if "disable_fallbacks" in key_metadata and isinstance(
            key_metadata["disable_fallbacks"], bool
        ):
            data["disable_fallbacks"] = key_metadata["disable_fallbacks"]
        return data
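
    # Illustrative usage sketch (not part of the original class): key-level tags
    # are merged into the request metadata without dropping tags already sent on
    # the request, and key-level disable_fallbacks is copied onto the request.
    @staticmethod
    def _example_add_key_level_controls() -> None:
        updated = LiteLLMProxyRequestSetup.add_key_level_controls(
            key_metadata={"tags": ["team-a"], "disable_fallbacks": True},
            data={"metadata": {"tags": ["from-request"]}},
            _metadata_variable_name="metadata",
        )
        assert updated["metadata"]["tags"] == ["from-request", "team-a"]
        assert updated["disable_fallbacks"] is True
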
    @staticmethod
    def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list:
        """
        Helper function to merge two lists of tags, ensuring no duplicates.

        Args:
            request_tags (Optional[list]): List of tags from the original request
            tags_to_add (Optional[list]): List of tags to add

        Returns:
            list: Combined list of unique tags
        """
        final_tags = []

        if request_tags and isinstance(request_tags, list):
            final_tags.extend(request_tags)

        if tags_to_add and isinstance(tags_to_add, list):
            for tag in tags_to_add:
                if tag not in final_tags:
                    final_tags.append(tag)

        return final_tags
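
    # Illustrative usage sketch (not part of the original class): duplicates are
    # dropped while the order of the request tags is preserved.
    @staticmethod
    def _example_merge_tags() -> None:
        merged = LiteLLMProxyRequestSetup._merge_tags(
            request_tags=["prod", "batch"], tags_to_add=["batch", "team-a"]
        )
        assert merged == ["prod", "batch", "team-a"]
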
    @staticmethod
    def add_team_based_callbacks_from_config(
        team_id: str,
        proxy_config: ProxyConfig,
    ) -> Optional[TeamCallbackMetadata]:
        """
        Add team-based callbacks from the config
        """
        team_config = proxy_config.load_team_config(team_id=team_id)
        if len(team_config.keys()) == 0:
            return None

        callback_vars_dict = {**team_config.get("callback_vars", team_config)}
        callback_vars_dict.pop("team_id", None)
        callback_vars_dict.pop("success_callback", None)
        callback_vars_dict.pop("failure_callback", None)

        return TeamCallbackMetadata(
            success_callback=team_config.get("success_callback", None),
            failure_callback=team_config.get("failure_callback", None),
            callback_vars=callback_vars_dict,
        )

    @staticmethod
    def add_request_tag_to_metadata(
        llm_router: Optional[Router],
        headers: dict,
        data: dict,
    ) -> Optional[List[str]]:
        tags = None

        # Check request headers for tags
        if "x-litellm-tags" in headers:
            if isinstance(headers["x-litellm-tags"], str):
                _tags = headers["x-litellm-tags"].split(",")
                tags = [tag.strip() for tag in _tags]
            elif isinstance(headers["x-litellm-tags"], list):
                tags = headers["x-litellm-tags"]
        # Check request body for tags
        if "tags" in data and isinstance(data["tags"], list):
            tags = data["tags"]

        return tags
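
    # Illustrative usage sketch (not part of the original class): header tags
    # arrive as a comma-separated string; tags in the request body win over
    # header tags.
    @staticmethod
    def _example_add_request_tag_to_metadata() -> None:
        tags = LiteLLMProxyRequestSetup.add_request_tag_to_metadata(
            llm_router=None,
            headers={"x-litellm-tags": "team-a, prod"},
            data={},
        )
        assert tags == ["team-a", "prod"]
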
async def add_litellm_data_to_request(  # noqa: PLR0915
    data: dict,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    proxy_config: ProxyConfig,
    general_settings: Optional[Dict[str, Any]] = None,
    version: Optional[str] = None,
):
    """
    Adds LiteLLM-specific data to the request.

    Args:
        data (dict): The data dictionary to be modified.
        request (Request): The incoming request.
        user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
        general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
        version (Optional[str], optional): Version. Defaults to None.

    Returns:
        dict: The modified data dictionary.

    """
    from litellm.proxy.proxy_server import llm_router, premium_user

    safe_add_api_version_from_query_params(data, request)

    _headers = clean_headers(
        request.headers,
        litellm_key_header_name=(
            general_settings.get("litellm_key_header_name")
            if general_settings is not None
            else None
        ),
    )

    data.update(
        LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
            headers=_headers,
            user_api_key_dict=user_api_key_dict,
            general_settings=general_settings,
        )
    )

    # Include original request and headers in the data
    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
        "headers": _headers,
        "body": copy.copy(data),  # use copy instead of deepcopy
    }

    ## Dynamic api version (Azure OpenAI endpoints) ##
    try:
        query_params = request.query_params
        # Convert query parameters to a dictionary (optional)
        query_dict = dict(query_params)
    except KeyError:
        query_dict = {}

    ## check for api version in query params
    dynamic_api_version: Optional[str] = query_dict.get("api-version")

    if dynamic_api_version is not None:  # only pass, if set
        data["api_version"] = dynamic_api_version

    ## Forward any LLM API Provider specific headers in extra_headers
    add_provider_specific_headers_to_request(data=data, headers=_headers)

    ## Cache Controls
    headers = request.headers
    verbose_proxy_logger.debug("Request Headers: %s", headers)
    cache_control_header = headers.get("Cache-Control", None)
    if cache_control_header:
        cache_dict = parse_cache_control(cache_control_header)
        data["ttl"] = cache_dict.get("s-maxage")

    verbose_proxy_logger.debug("receiving data: %s", data)

    _metadata_variable_name = _get_metadata_variable_name(request)

    if data.get(_metadata_variable_name, None) is None:
        data[_metadata_variable_name] = {}

    # We want to log the "metadata" from the client side request. Avoid circular reference by not directly assigning metadata to itself.
    if "metadata" in data and data["metadata"] is not None:
        data[_metadata_variable_name]["requester_metadata"] = copy.deepcopy(
            data["metadata"]
        )

    user_api_key_logged_metadata = (
        LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
            user_api_key_dict=user_api_key_dict
        )
    )
    data[_metadata_variable_name].update(user_api_key_logged_metadata)
    data[_metadata_variable_name][
        "user_api_key"
    ] = (
        user_api_key_dict.api_key
    )  # this is just the hashed token. [TODO]: replace variable name in repo.

    data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr(
        user_api_key_dict, "end_user_max_budget", None
    )

    data[_metadata_variable_name]["litellm_api_version"] = version

    if general_settings is not None:
        data[_metadata_variable_name][
            "global_max_parallel_requests"
        ] = general_settings.get("global_max_parallel_requests", None)

    ### KEY-LEVEL Controls
    key_metadata = user_api_key_dict.metadata
    data = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata=key_metadata,
        data=data,
        _metadata_variable_name=_metadata_variable_name,
    )
    ## TEAM-LEVEL SPEND LOGS/TAGS
    team_metadata = user_api_key_dict.team_metadata or {}
    if "tags" in team_metadata and team_metadata["tags"] is not None:
        data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
            request_tags=data[_metadata_variable_name].get("tags"),
            tags_to_add=team_metadata["tags"],
        )
    if "spend_logs_metadata" in team_metadata and isinstance(
        team_metadata["spend_logs_metadata"], dict
    ):
        if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
            data[_metadata_variable_name]["spend_logs_metadata"], dict
        ):
            for key, value in team_metadata["spend_logs_metadata"].items():
                if (
                    key not in data[_metadata_variable_name]["spend_logs_metadata"]
                ):  # don't override k-v pair sent by request (user request)
                    data[_metadata_variable_name]["spend_logs_metadata"][key] = value
        else:
            data[_metadata_variable_name]["spend_logs_metadata"] = team_metadata[
                "spend_logs_metadata"
            ]

    # Team spend, budget - used by prometheus.py
    data[_metadata_variable_name][
        "user_api_key_team_max_budget"
    ] = user_api_key_dict.team_max_budget
    data[_metadata_variable_name][
        "user_api_key_team_spend"
    ] = user_api_key_dict.team_spend

    # API Key spend, budget - used by prometheus.py
    data[_metadata_variable_name]["user_api_key_spend"] = user_api_key_dict.spend
    data[_metadata_variable_name][
        "user_api_key_max_budget"
    ] = user_api_key_dict.max_budget
    data[_metadata_variable_name][
        "user_api_key_model_max_budget"
    ] = user_api_key_dict.model_max_budget

    data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
    _headers = dict(request.headers)
    _headers.pop(
        "authorization", None
    )  # do not store the original `sk-..` api key in the db
    data[_metadata_variable_name]["headers"] = _headers
    data[_metadata_variable_name]["endpoint"] = str(request.url)

    # OTEL Controls / Tracing
    # Add the OTEL Parent Trace before sending it to LiteLLM
    data[_metadata_variable_name][
        "litellm_parent_otel_span"
    ] = user_api_key_dict.parent_otel_span
    _add_otel_traceparent_to_data(data, request=request)

    ### END-USER SPECIFIC PARAMS ###
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region
    start_time = time.time()
    ## [Enterprise Only]
    # Add User-IP Address
    requester_ip_address = ""
    if premium_user is True:
        # Only set the IP Address for Enterprise Users

        # logic for tracking IP Address
        if (
            general_settings is not None
            and general_settings.get("use_x_forwarded_for") is True
            and request is not None
            and hasattr(request, "headers")
            and "x-forwarded-for" in request.headers
        ):
            requester_ip_address = request.headers["x-forwarded-for"]
        elif (
            request is not None
            and hasattr(request, "client")
            and hasattr(request.client, "host")
            and request.client is not None
        ):
            requester_ip_address = request.client.host
    data[_metadata_variable_name]["requester_ip_address"] = requester_ip_address

    # Check if using tag based routing
    tags = LiteLLMProxyRequestSetup.add_request_tag_to_metadata(
        llm_router=llm_router,
        headers=dict(request.headers),
        data=data,
    )

    if tags is not None:
        data[_metadata_variable_name]["tags"] = tags

    # Team Callbacks controls
    callback_settings_obj = _get_dynamic_logging_metadata(
        user_api_key_dict=user_api_key_dict, proxy_config=proxy_config
    )
    if callback_settings_obj is not None:
        data["success_callback"] = callback_settings_obj.success_callback
        data["failure_callback"] = callback_settings_obj.failure_callback

        if callback_settings_obj.callback_vars is not None:
            # unpack callback_vars in data
            for k, v in callback_settings_obj.callback_vars.items():
                data[k] = v

    # Guardrails
    move_guardrails_to_metadata(
        data=data,
        _metadata_variable_name=_metadata_variable_name,
        user_api_key_dict=user_api_key_dict,
    )

    # Team Model Aliases
    _update_model_if_team_alias_exists(
        data=data,
        user_api_key_dict=user_api_key_dict,
    )

    verbose_proxy_logger.debug(
        "[PROXY] returned data from litellm_pre_call_utils: %s", data
    )

    ## ENFORCED PARAMS CHECK
    # loop through each enforced param
    # example enforced_params ['user', 'metadata', 'metadata.generation_name']
    _enforced_params_check(
        request_body=data,
        general_settings=general_settings,
        user_api_key_dict=user_api_key_dict,
        premium_user=premium_user,
    )

    end_time = time.time()
    asyncio.create_task(
        service_logger_obj.async_service_success_hook(
            service=ServiceTypes.PROXY_PRE_CALL,
            duration=end_time - start_time,
            call_type="add_litellm_data_to_request",
            start_time=start_time,
            end_time=end_time,
            parent_otel_span=user_api_key_dict.parent_otel_span,
        )
    )

    return data

def _update_model_if_team_alias_exists(
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
) -> None:
    """
    Update the model if the team alias exists

    If an alias map has been set on a team, then we want to make the request with the model the team alias is pointing to

    eg.
    - user calls `gpt-4o`
    - team.model_alias_map = {
          "gpt-4o": "gpt-4o-team-1"
      }
    - requested_model = "gpt-4o-team-1"
    """
    _model = data.get("model")
    if (
        _model
        and user_api_key_dict.team_model_aliases
        and _model in user_api_key_dict.team_model_aliases
    ):
        data["model"] = user_api_key_dict.team_model_aliases[_model]
    return
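
# Illustrative usage sketch (not part of the original module). Assumes
# UserAPIKeyAuth can be constructed with only the team_model_aliases field that
# is read above.
def _example_update_model_if_team_alias_exists() -> None:
    data = {"model": "gpt-4o"}
    _update_model_if_team_alias_exists(
        data=data,
        user_api_key_dict=UserAPIKeyAuth(
            team_model_aliases={"gpt-4o": "gpt-4o-team-1"}
        ),
    )
    assert data["model"] == "gpt-4o-team-1"
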
def _get_enforced_params(
    general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
) -> Optional[list]:
    enforced_params: Optional[list] = None
    if general_settings is not None:
        enforced_params = general_settings.get("enforced_params")
        if (
            "service_account_settings" in general_settings
            and check_if_token_is_service_account(user_api_key_dict) is True
        ):
            service_account_settings = general_settings["service_account_settings"]
            if "enforced_params" in service_account_settings:
                if enforced_params is None:
                    enforced_params = []
                enforced_params.extend(service_account_settings["enforced_params"])
    if user_api_key_dict.metadata.get("enforced_params", None) is not None:
        if enforced_params is None:
            enforced_params = []
        enforced_params.extend(user_api_key_dict.metadata["enforced_params"])
    return enforced_params


def check_if_token_is_service_account(valid_token: UserAPIKeyAuth) -> bool:
    """
    Checks if the token is a service account

    Returns:
        bool: True if token is a service account

    """
    if valid_token.metadata:
        if "service_account_id" in valid_token.metadata:
            return True
    return False
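
# Illustrative usage sketch (not part of the original module). Assumes
# UserAPIKeyAuth can be constructed with only a metadata dict.
def _example_check_if_token_is_service_account() -> None:
    assert check_if_token_is_service_account(
        UserAPIKeyAuth(metadata={"service_account_id": "sa-123"})
    )
    assert not check_if_token_is_service_account(UserAPIKeyAuth(metadata={}))
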
def _enforced_params_check(
    request_body: dict,
    general_settings: Optional[dict],
    user_api_key_dict: UserAPIKeyAuth,
    premium_user: bool,
) -> bool:
    """
    If enforced params are set, check if the request body contains the enforced params.
    """
    enforced_params: Optional[list] = _get_enforced_params(
        general_settings=general_settings, user_api_key_dict=user_api_key_dict
    )
    if enforced_params is None:
        return True
    if enforced_params is not None and premium_user is not True:
        raise ValueError(
            f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}"
        )

    for enforced_param in enforced_params:
        _enforced_params = enforced_param.split(".")
        if len(_enforced_params) == 1:
            if _enforced_params[0] not in request_body:
                raise ValueError(
                    f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
                )
        elif len(_enforced_params) == 2:
            # this is a scenario where user requires request['metadata']['generation_name'] to exist
            if _enforced_params[0] not in request_body:
                raise ValueError(
                    f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
                )
            if _enforced_params[1] not in request_body[_enforced_params[0]]:
                raise ValueError(
                    f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
                )
    return True

def _add_guardrails_from_key_or_team_metadata(
    key_metadata: Optional[dict],
    team_metadata: Optional[dict],
    data: dict,
    metadata_variable_name: str,
) -> None:
    """
    Helper to add guardrails from key or team metadata to the request data

    Args:
        key_metadata: The key metadata dictionary to check for guardrails
        team_metadata: The team metadata dictionary to check for guardrails
        data: The request data to update
        metadata_variable_name: The name of the metadata field in data

    """
    from litellm.proxy.utils import _premium_user_check

    for _management_object_metadata in [key_metadata, team_metadata]:
        if _management_object_metadata and "guardrails" in _management_object_metadata:
            if len(_management_object_metadata["guardrails"]) > 0:
                _premium_user_check()

            data[metadata_variable_name]["guardrails"] = _management_object_metadata[
                "guardrails"
            ]

def move_guardrails_to_metadata(
    data: dict,
    _metadata_variable_name: str,
    user_api_key_dict: UserAPIKeyAuth,
):
    """
    Helper to add guardrails from the request to metadata

    - If guardrails are set on the API Key metadata, then sets guardrails on the request metadata
    - If guardrails are not set on the API key, then checks the request metadata
    """
    # Check key-level guardrails
    _add_guardrails_from_key_or_team_metadata(
        key_metadata=user_api_key_dict.metadata,
        team_metadata=user_api_key_dict.team_metadata,
        data=data,
        metadata_variable_name=_metadata_variable_name,
    )

    # Check request-level guardrails
    if "guardrails" in data:
        data[_metadata_variable_name]["guardrails"] = data["guardrails"]
        del data["guardrails"]

    if "guardrail_config" in data:
        data[_metadata_variable_name]["guardrail_config"] = data["guardrail_config"]
        del data["guardrail_config"]
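
# Illustrative usage sketch (not part of the original module): request-level
# "guardrails" are moved under the metadata field. Assumes a default
# UserAPIKeyAuth() carries no key- or team-level guardrails; the guardrail name
# below is hypothetical.
def _example_move_guardrails_to_metadata() -> None:
    data: dict = {"metadata": {}, "guardrails": ["my-pre-call-guard"]}
    move_guardrails_to_metadata(
        data=data,
        _metadata_variable_name="metadata",
        user_api_key_dict=UserAPIKeyAuth(),
    )
    assert data == {"metadata": {"guardrails": ["my-pre-call-guard"]}}
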
def add_provider_specific_headers_to_request(
    data: dict,
    headers: dict,
):
    anthropic_headers = {}
    # boolean to indicate if a header was added
    added_header = False
    for header in ANTHROPIC_API_HEADERS:
        if header in headers:
            header_value = headers[header]
            anthropic_headers[header] = header_value
            added_header = True

    if added_header is True:
        data["provider_specific_header"] = ProviderSpecificHeader(
            custom_llm_provider="anthropic",
            extra_headers=anthropic_headers,
        )

    return
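
# Illustrative usage sketch (not part of the original module). Assumes
# "anthropic-beta" is one of the entries in ANTHROPIC_API_HEADERS.
def _example_add_provider_specific_headers_to_request() -> None:
    data: dict = {}
    add_provider_specific_headers_to_request(
        data=data, headers={"anthropic-beta": "prompt-caching-2024-07-31"}
    )
    assert "provider_specific_header" in data
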
def _add_otel_traceparent_to_data(data: dict, request: Request):
    from litellm.proxy.proxy_server import open_telemetry_logger

    if data is None:
        return
    if open_telemetry_logger is None:
        # if the user is not using OTEL, don't send extra_headers
        # relevant issue: https://github.com/BerriAI/litellm/issues/4448
        return

    if litellm.forward_traceparent_to_llm_provider is True:
        if request.headers:
            if "traceparent" in request.headers:
                # we want to forward this to the LLM Provider
                # Relevant issue: https://github.com/BerriAI/litellm/issues/4419
                # pass this in extra_headers
                if "extra_headers" not in data:
                    data["extra_headers"] = {}
                _extra_headers = data["extra_headers"]
                if "traceparent" not in _extra_headers:
                    _extra_headers["traceparent"] = request.headers["traceparent"]