diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 5aca61a96..2e670de85 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1,8 +1,10 @@ import copy +from typing import TYPE_CHECKING, Any, Dict, Optional + from fastapi import Request -from typing import Any, Dict, Optional, TYPE_CHECKING + +from litellm._logging import verbose_logger, verbose_proxy_logger from litellm.proxy._types import UserAPIKeyAuth -from litellm._logging import verbose_proxy_logger, verbose_logger from litellm.types.utils import SupportedCacheControls if TYPE_CHECKING: @@ -27,6 +29,20 @@ def parse_cache_control(cache_control): return cache_dict +def _get_metadata_variable_name(request: Request) -> str: + """ + Helper to return what the "metadata" field should be called in the request data + + For all /thread endpoints we need to call this "litellm_metadata" + + For ALL other endpoints we call this "metadata + """ + if "thread" in request.url.path: + return "litellm_metadata" + else: + return "metadata" + + async def add_litellm_data_to_request( data: dict, request: Request, @@ -80,48 +96,58 @@ async def add_litellm_data_to_request( verbose_proxy_logger.debug("receiving data: %s", data) - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_alias"] = getattr( + _metadata_variable_name = _get_metadata_variable_name(request) + + if _metadata_variable_name not in data: + data[_metadata_variable_name] = {} + data[_metadata_variable_name]["user_api_key"] = user_api_key_dict.api_key + data[_metadata_variable_name]["user_api_key_alias"] = getattr( user_api_key_dict, "key_alias", None ) - data["metadata"]["user_api_end_user_max_budget"] = getattr( + data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr( user_api_key_dict, "end_user_max_budget", None ) - data["metadata"]["litellm_api_version"] = version + data[_metadata_variable_name]["litellm_api_version"] = version if general_settings is not None: - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None + data[_metadata_variable_name]["global_max_parallel_requests"] = ( + general_settings.get("global_max_parallel_requests", None) ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id - data["metadata"]["user_api_key_team_id"] = getattr( + data[_metadata_variable_name]["user_api_key_user_id"] = user_api_key_dict.user_id + data[_metadata_variable_name]["user_api_key_org_id"] = user_api_key_dict.org_id + data[_metadata_variable_name]["user_api_key_team_id"] = getattr( user_api_key_dict, "team_id", None ) - data["metadata"]["user_api_key_team_alias"] = getattr( + data[_metadata_variable_name]["user_api_key_team_alias"] = getattr( user_api_key_dict, "team_alias", None ) # Team spend, budget - used by prometheus.py - data["metadata"]["user_api_key_team_max_budget"] = user_api_key_dict.team_max_budget - data["metadata"]["user_api_key_team_spend"] = user_api_key_dict.team_spend + data[_metadata_variable_name][ + "user_api_key_team_max_budget" + ] = user_api_key_dict.team_max_budget + data[_metadata_variable_name][ + "user_api_key_team_spend" + ] = user_api_key_dict.team_spend # API Key spend, budget - used by prometheus.py - data["metadata"]["user_api_key_spend"] = user_api_key_dict.spend - data["metadata"]["user_api_key_max_budget"] = user_api_key_dict.max_budget + data[_metadata_variable_name]["user_api_key_spend"] = user_api_key_dict.spend + data[_metadata_variable_name][ + "user_api_key_max_budget" + ] = user_api_key_dict.max_budget - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata _headers = dict(request.headers) _headers.pop( "authorization", None ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["endpoint"] = str(request.url) + data[_metadata_variable_name]["headers"] = _headers + data[_metadata_variable_name]["endpoint"] = str(request.url) # Add the OTEL Parent Trace before sending it LiteLLM - data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span + data[_metadata_variable_name][ + "litellm_parent_otel_span" + ] = user_api_key_dict.parent_otel_span ### END-USER SPECIFIC PARAMS ### if user_api_key_dict.allowed_model_region is not None: @@ -136,7 +162,7 @@ async def add_litellm_data_to_request( pass else: team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id + data[_metadata_variable_name]["team_id"] = team_id data = { **team_config, **data,