diff --git a/litellm/__init__.py b/litellm/__init__.py
index 11b34f504..3282660e9 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -80,6 +80,9 @@ turn_off_message_logging: Optional[bool] = False
 log_raw_request_response: bool = False
 redact_messages_in_exceptions: Optional[bool] = False
 redact_user_api_key_info: Optional[bool] = False
+add_user_information_to_llm_headers: Optional[bool] = (
+    None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+)
 store_audit_logs = False  # Enterprise feature, allow users to see audit logs
 ## end of callbacks #############
 
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index fd7335201..5201bfe1e 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -2798,6 +2798,52 @@ def get_standard_logging_object_payload(
         return None
 
 
+def get_standard_logging_metadata(
+    metadata: Optional[Dict[str, Any]]
+) -> StandardLoggingMetadata:
+    """
+    Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
+
+    Args:
+        metadata (Optional[Dict[str, Any]]): The original metadata dictionary.
+
+    Returns:
+        StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata.
+
+    Note:
+        - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
+        - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
+    """
+    # Initialize with default values
+    clean_metadata = StandardLoggingMetadata(
+        user_api_key_hash=None,
+        user_api_key_alias=None,
+        user_api_key_team_id=None,
+        user_api_key_org_id=None,
+        user_api_key_user_id=None,
+        user_api_key_team_alias=None,
+        spend_logs_metadata=None,
+        requester_ip_address=None,
+        requester_metadata=None,
+    )
+    if isinstance(metadata, dict):
+        # Filter the metadata dictionary to include only the specified keys
+        clean_metadata = StandardLoggingMetadata(
+            **{  # type: ignore
+                key: metadata[key]
+                for key in StandardLoggingMetadata.__annotations__.keys()
+                if key in metadata
+            }
+        )
+
+        if metadata.get("user_api_key") is not None:
+            if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
+                clean_metadata["user_api_key_hash"] = metadata.get(
+                    "user_api_key"
+                )  # this is the hash
+    return clean_metadata
+
+
 def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
     if litellm_params is None:
         litellm_params = {}
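For context, a minimal standalone sketch of the filtering behavior the new `get_standard_logging_metadata` helper implements. The `StandardLoggingMetadata` and `is_valid_sha256_hash` below are reduced stand-ins for litellm's real definitions, trimmed to two keys for brevity:

```python
import re
from typing import Optional, TypedDict


class StandardLoggingMetadata(TypedDict):
    # reduced to two keys for brevity; the real TypedDict declares more
    user_api_key_hash: Optional[str]
    user_api_key_team_id: Optional[str]


def is_valid_sha256_hash(value: str) -> bool:
    # a SHA-256 hex digest is exactly 64 hex characters
    return bool(re.fullmatch(r"[0-9a-fA-F]{64}", value))


metadata = {
    "user_api_key": "a" * 64,            # valid sha256 hex -> kept as user_api_key_hash
    "user_api_key_team_id": "team-123",  # declared key -> kept
    "internal_debug_flag": True,         # undeclared key -> dropped
}
clean = {
    k: metadata[k] for k in StandardLoggingMetadata.__annotations__ if k in metadata
}
if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
    clean["user_api_key_hash"] = metadata["user_api_key"]
print(clean)  # {'user_api_key_team_id': 'team-123', 'user_api_key_hash': 'aaa...a'}
```

Because undeclared keys are dropped and `user_api_key` is only kept when it is already a hash, arbitrary request metadata cannot leak raw keys into logging payloads.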
- """ - forwarded_headers = {} - for header, value in headers.items(): - if header.lower().startswith("x-") and not header.lower().startswith( - "x-stainless" - ): # causes openai sdk to fail - forwarded_headers[header] = value - - return forwarded_headers - - -def get_openai_org_id_from_headers( - headers: dict, general_settings: Optional[Dict] = None -) -> Optional[str]: - """ - Get the OpenAI Org ID from the headers. - """ - if ( - general_settings is not None - and general_settings.get("forward_openai_org_id") is not True +class LiteLLMProxyRequestSetup: + @staticmethod + def _get_forwardable_headers( + headers: Union[Headers, dict], ): + """ + Get the headers that should be forwarded to the LLM Provider. + + Looks for any `x-` headers and sends them to the LLM Provider. + """ + forwarded_headers = {} + for header, value in headers.items(): + if header.lower().startswith("x-") and not header.lower().startswith( + "x-stainless" + ): # causes openai sdk to fail + forwarded_headers[header] = value + + return forwarded_headers + + @staticmethod + def get_openai_org_id_from_headers( + headers: dict, general_settings: Optional[Dict] = None + ) -> Optional[str]: + """ + Get the OpenAI Org ID from the headers. + """ + if ( + general_settings is not None + and general_settings.get("forward_openai_org_id") is not True + ): + return None + for header, value in headers.items(): + if header.lower() == "openai-organization": + return value return None - for header, value in headers.items(): - if header.lower() == "openai-organization": - return value - return None + @staticmethod + def add_headers_to_llm_call( + headers: dict, user_api_key_dict: UserAPIKeyAuth + ) -> dict: + """ + Add headers to the LLM call -def add_litellm_data_for_backend_llm_call( - headers: dict, general_settings: Optional[Dict[str, Any]] = None -) -> LitellmDataForBackendLLMCall: - """ - - Adds forwardable headers - - Adds org id - """ - data = LitellmDataForBackendLLMCall() - _headers = get_forwardable_headers(headers) - if _headers != {}: - data["headers"] = _headers - _organization = get_openai_org_id_from_headers(headers, general_settings) - if _organization is not None: - data["organization"] = _organization - return data + - Checks request headers for forwardable headers + - Checks if user information should be added to the headers + """ + from litellm.litellm_core_utils.litellm_logging import ( + get_standard_logging_metadata, + ) + + returned_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(headers) + + if litellm.add_user_information_to_llm_headers is True: + litellm_logging_metadata_headers = ( + LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key( + user_api_key_dict=user_api_key_dict + ) + ) + for k, v in litellm_logging_metadata_headers.items(): + if v is not None: + returned_headers["x-litellm-{}".format(k)] = v + + return returned_headers + + @staticmethod + def add_litellm_data_for_backend_llm_call( + *, + headers: dict, + user_api_key_dict: UserAPIKeyAuth, + general_settings: Optional[Dict[str, Any]] = None, + ) -> LitellmDataForBackendLLMCall: + """ + - Adds forwardable headers + - Adds org id + """ + data = LitellmDataForBackendLLMCall() + _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call( + headers, user_api_key_dict + ) + if _headers != {}: + data["headers"] = _headers + _organization = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers( + headers, general_settings + ) + if _organization is not None: + data["organization"] = _organization + return data + + 
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 8cc0844b3..0b7a29c91 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1412,16 +1412,20 @@ class AdapterCompletionStreamWrapper:
         raise StopAsyncIteration
 
 
-class StandardLoggingMetadata(TypedDict):
+class StandardLoggingUserAPIKeyMetadata(TypedDict):
+    user_api_key_hash: Optional[str]  # hash of the litellm virtual key used
+    user_api_key_alias: Optional[str]
+    user_api_key_org_id: Optional[str]
+    user_api_key_team_id: Optional[str]
+    user_api_key_user_id: Optional[str]
+    user_api_key_team_alias: Optional[str]
+
+
+class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
     """
     Specific metadata k,v pairs logged to integration for easier cost tracking
     """
 
-    user_api_key_hash: Optional[str]  # hash of the litellm virtual key used
-    user_api_key_alias: Optional[str]
-    user_api_key_team_id: Optional[str]
-    user_api_key_user_id: Optional[str]
-    user_api_key_team_alias: Optional[str]
     spend_logs_metadata: Optional[
         dict
     ]  # special param to log k,v pairs to spendlogs for a call
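The TypedDict split above works because a TypedDict subclass merges its parent's fields into its own `__annotations__`, which is exactly what `get_standard_logging_metadata` iterates over. A reduced sketch (two fields instead of the full set):

```python
from typing import Optional, TypedDict


class StandardLoggingUserAPIKeyMetadata(TypedDict):
    user_api_key_hash: Optional[str]
    user_api_key_team_id: Optional[str]


class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
    spend_logs_metadata: Optional[dict]


# inherited keys are merged into the subclass's __annotations__
print(sorted(StandardLoggingMetadata.__annotations__))
# ['spend_logs_metadata', 'user_api_key_hash', 'user_api_key_team_id']
```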
diff --git a/tests/local_testing/test_proxy_server.py b/tests/local_testing/test_proxy_server.py
index e92d84c55..803243557 100644
--- a/tests/local_testing/test_proxy_server.py
+++ b/tests/local_testing/test_proxy_server.py
@@ -203,7 +203,7 @@ def test_add_headers_to_request(litellm_key_header_name):
     import json
     from litellm.proxy.litellm_pre_call_utils import (
         clean_headers,
-        get_forwardable_headers,
+        LiteLLMProxyRequestSetup,
     )
 
     headers = {
@@ -215,7 +215,9 @@ def test_add_headers_to_request(litellm_key_header_name):
     request._url = URL(url="/chat/completions")
     request._body = json.dumps({"model": "gpt-3.5-turbo"}).encode("utf-8")
     request_headers = clean_headers(headers, litellm_key_header_name)
-    forwarded_headers = get_forwardable_headers(request_headers)
+    forwarded_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(
+        request_headers
+    )
     assert forwarded_headers == {"X-Custom-Header": "Custom-Value"}
 
diff --git a/tests/local_testing/test_proxy_utils.py b/tests/local_testing/test_proxy_utils.py
index a74e9e78b..74ef75392 100644
--- a/tests/local_testing/test_proxy_utils.py
+++ b/tests/local_testing/test_proxy_utils.py
@@ -371,12 +371,12 @@ def test_is_request_body_safe_model_enabled(
 
 
 def test_reading_openai_org_id_from_headers():
-    from litellm.proxy.litellm_pre_call_utils import get_openai_org_id_from_headers
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
 
     headers = {
         "OpenAI-Organization": "test_org_id",
     }
-    org_id = get_openai_org_id_from_headers(headers)
+    org_id = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(headers)
     assert org_id == "test_org_id"
 
 
@@ -399,11 +399,44 @@
 )
 def test_add_litellm_data_for_backend_llm_call(headers, expected_data):
     import json
-    from litellm.proxy.litellm_pre_call_utils import (
-        add_litellm_data_for_backend_llm_call,
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+    from litellm.proxy._types import UserAPIKeyAuth
+
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
     )
-    data = add_litellm_data_for_backend_llm_call(headers)
+    data = LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
+        headers=headers,
+        user_api_key_dict=user_api_key_dict,
+        general_settings=None,
+    )
+
+    assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)
+
+
+def test_forward_litellm_user_info_to_backend_llm_call():
+    import json
+
+    litellm.add_user_information_to_llm_headers = True
+
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+    from litellm.proxy._types import UserAPIKeyAuth
+
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+    )
+
+    data = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
+        headers={},
+        user_api_key_dict=user_api_key_dict,
+    )
+
+    expected_data = {
+        "x-litellm-user_api_key_user_id": "test_user_id",
+        "x-litellm-user_api_key_org_id": "test_org_id",
+        "x-litellm-user_api_key_hash": "test_api_key",
+    }
 
     assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)
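One note on the new test: it sets the module-level `litellm.add_user_information_to_llm_headers` flag without restoring it, which can leak into later tests in the same session. A sketch of an isolated variant using pytest's `monkeypatch` fixture (hypothetical test name, built only from APIs shown in this diff):

```python
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup


def test_user_info_headers_flag_isolated(monkeypatch):
    # monkeypatch restores the module-level flag automatically after the test
    monkeypatch.setattr(litellm, "add_user_information_to_llm_headers", True)
    headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
        headers={},
        user_api_key_dict=UserAPIKeyAuth(
            api_key="test_api_key", user_id="test_user_id"
        ),
    )
    assert headers["x-litellm-user_api_key_user_id"] == "test_user_id"
```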