forked from phoenix/litellm-mirror
feat(litellm_pre_call_utils.py): support 'add_user_information_to_llm_headers' (#6390)
* feat(litellm_pre_call_utils.py): support 'add_user_information_to_llm_headers' param; enables passing user info to the backend LLM (user request for custom vLLM server)
* fix(litellm_logging.py): fix linting error
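For context, a minimal usage sketch of the new setting (illustrative, not part of the diff): once the flag is enabled, the proxy is expected to attach x-litellm-* headers derived from the authenticated virtual key to the request it sends to the backend LLM (e.g. a custom vLLM server).

import litellm

# Module-level setting added in this commit; defaults to None (off).
litellm.add_user_information_to_llm_headers = True

# With the flag on, the proxy's pre-call setup adds headers such as:
#   x-litellm-user_api_key_user_id: <user_id>
#   x-litellm-user_api_key_team_id: <team_id>
#   x-litellm-user_api_key_hash:    <hashed virtual key>
# so the backend server can attribute traffic per user/team/key.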
Parent: 4e310051c7
Commit: 9fccf829b1
6 changed files with 221 additions and 73 deletions
@@ -80,6 +80,9 @@ turn_off_message_logging: Optional[bool] = False
log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
+add_user_information_to_llm_headers: Optional[bool] = (
+    None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+)
store_audit_logs = False  # Enterprise feature, allow users to see audit logs
## end of callbacks #############

@@ -2798,6 +2798,52 @@ def get_standard_logging_object_payload(
        return None


+def get_standard_logging_metadata(
+    metadata: Optional[Dict[str, Any]]
+) -> StandardLoggingMetadata:
+    """
+    Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
+
+    Args:
+        metadata (Optional[Dict[str, Any]]): The original metadata dictionary.
+
+    Returns:
+        StandardLoggingMetadata: A StandardLoggingMetadata object containing the cleaned metadata.
+
+    Note:
+        - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
+        - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
+    """
+    # Initialize with default values
+    clean_metadata = StandardLoggingMetadata(
+        user_api_key_hash=None,
+        user_api_key_alias=None,
+        user_api_key_team_id=None,
+        user_api_key_org_id=None,
+        user_api_key_user_id=None,
+        user_api_key_team_alias=None,
+        spend_logs_metadata=None,
+        requester_ip_address=None,
+        requester_metadata=None,
+    )
+    if isinstance(metadata, dict):
+        # Filter the metadata dictionary to include only the specified keys
+        clean_metadata = StandardLoggingMetadata(
+            **{  # type: ignore
+                key: metadata[key]
+                for key in StandardLoggingMetadata.__annotations__.keys()
+                if key in metadata
+            }
+        )
+
+        if metadata.get("user_api_key") is not None:
+            if is_valid_sha256_hash(str(metadata.get("user_api_key"))):
+                clean_metadata["user_api_key_hash"] = metadata.get(
+                    "user_api_key"
+                )  # this is the hash
+    return clean_metadata
+
+
def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
    if litellm_params is None:
        litellm_params = {}

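For illustration, a rough sketch of how the helper added above behaves (hypothetical input values, assuming is_valid_sha256_hash accepts a 64-character hex string): keys not declared on StandardLoggingMetadata are dropped, and a pre-hashed user_api_key is carried over into user_api_key_hash.

raw_metadata = {
    "user_api_key": "a" * 64,          # assume a valid SHA256 hex digest of the virtual key
    "user_api_key_team_id": "team-1",  # declared on StandardLoggingMetadata, so it is kept
    "internal_note": "dropped",        # not a StandardLoggingMetadata field, so it is filtered out
}

clean = get_standard_logging_metadata(raw_metadata)
assert clean["user_api_key_team_id"] == "team-1"
assert clean["user_api_key_hash"] == raw_metadata["user_api_key"]
assert "internal_note" not in clean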
@@ -16,7 +16,10 @@ from litellm.proxy._types import (
    UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import get_request_route
-from litellm.types.utils import SupportedCacheControls
+from litellm.types.utils import (
+    StandardLoggingUserAPIKeyMetadata,
+    SupportedCacheControls,
+)

if TYPE_CHECKING:
    from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig

@@ -159,7 +162,9 @@ def clean_headers(
    return clean_headers


-def get_forwardable_headers(
-    headers: Union[Headers, dict],
-):
-    """
+class LiteLLMProxyRequestSetup:
+    @staticmethod
+    def _get_forwardable_headers(
+        headers: Union[Headers, dict],
+    ):
+        """

@@ -176,7 +181,7 @@ def get_forwardable_headers(

-    return forwarded_headers
-
-
-def get_openai_org_id_from_headers(
-    headers: dict, general_settings: Optional[Dict] = None
-) -> Optional[str]:
+        return forwarded_headers
+
+    @staticmethod
+    def get_openai_org_id_from_headers(
+        headers: dict, general_settings: Optional[Dict] = None
+    ) -> Optional[str]:

@@ -193,23 +198,72 @@ def get_openai_org_id_from_headers(
                return value
        return None

+    @staticmethod
+    def add_headers_to_llm_call(
+        headers: dict, user_api_key_dict: UserAPIKeyAuth
+    ) -> dict:
+        """
+        Add headers to the LLM call
+
+        - Checks request headers for forwardable headers
+        - Checks if user information should be added to the headers
+        """
+        from litellm.litellm_core_utils.litellm_logging import (
+            get_standard_logging_metadata,
+        )
+
+        returned_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(headers)
+
+        if litellm.add_user_information_to_llm_headers is True:
+            litellm_logging_metadata_headers = (
+                LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
+                    user_api_key_dict=user_api_key_dict
+                )
+            )
+            for k, v in litellm_logging_metadata_headers.items():
+                if v is not None:
+                    returned_headers["x-litellm-{}".format(k)] = v
+
+        return returned_headers
+
+    @staticmethod
    def add_litellm_data_for_backend_llm_call(
-        headers: dict, general_settings: Optional[Dict[str, Any]] = None
+        *,
+        headers: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        general_settings: Optional[Dict[str, Any]] = None,
    ) -> LitellmDataForBackendLLMCall:
        """
        - Adds forwardable headers
        - Adds org id
        """
        data = LitellmDataForBackendLLMCall()
-        _headers = get_forwardable_headers(headers)
+        _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
+            headers, user_api_key_dict
+        )
        if _headers != {}:
            data["headers"] = _headers
-        _organization = get_openai_org_id_from_headers(headers, general_settings)
+        _organization = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(
+            headers, general_settings
+        )
        if _organization is not None:
            data["organization"] = _organization
        return data
+
+    @staticmethod
+    def get_sanitized_user_information_from_key(
+        user_api_key_dict: UserAPIKeyAuth,
+    ) -> StandardLoggingUserAPIKeyMetadata:
+        user_api_key_logged_metadata = StandardLoggingUserAPIKeyMetadata(
+            user_api_key_hash=user_api_key_dict.api_key,  # just the hashed token
+            user_api_key_alias=user_api_key_dict.key_alias,
+            user_api_key_team_id=user_api_key_dict.team_id,
+            user_api_key_user_id=user_api_key_dict.user_id,
+            user_api_key_org_id=user_api_key_dict.org_id,
+            user_api_key_team_alias=user_api_key_dict.team_alias,
+        )
+        return user_api_key_logged_metadata
+
+
async def add_litellm_data_to_request(  # noqa: PLR0915
    data: dict,

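Putting the pieces together, a rough sketch of what the refactored setup class produces for a backend call (values illustrative; this mirrors the tests further down rather than documenting guaranteed output).

import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

key = UserAPIKeyAuth(api_key="hashed-token", user_id="u-123", org_id="org-abc")

data = LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
    headers={"X-Custom-Header": "Custom-Value", "OpenAI-Organization": "test_org_id"},
    user_api_key_dict=key,
    general_settings=None,
)
# Expected shape per the helpers above: forwardable "X-*" headers are passed through,
# the OpenAI org id is surfaced separately, and, when
# litellm.add_user_information_to_llm_headers is True, x-litellm-* user info headers
# are merged into data["headers"] as well.
# data == {
#     "headers": {"X-Custom-Header": "Custom-Value"},
#     "organization": "test_org_id",
# }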
@@ -246,7 +300,13 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
        ),
    )

-    data.update(add_litellm_data_for_backend_llm_call(_headers, general_settings))
+    data.update(
+        LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
+            headers=_headers,
+            user_api_key_dict=user_api_key_dict,
+            general_settings=general_settings,
+        )
+    )

    # Include original request and headers in the data
    data["proxy_server_request"] = {

@@ -294,13 +354,22 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
            data["metadata"]
        )

-    data[_metadata_variable_name]["user_api_key"] = user_api_key_dict.api_key
-    data[_metadata_variable_name]["user_api_key_alias"] = getattr(
-        user_api_key_dict, "key_alias", None
-    )
+    user_api_key_logged_metadata = (
+        LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
+            user_api_key_dict=user_api_key_dict
+        )
+    )
+    data[_metadata_variable_name].update(user_api_key_logged_metadata)
+    data[_metadata_variable_name][
+        "user_api_key"
+    ] = (
+        user_api_key_dict.api_key
+    )  # this is just the hashed token. [TODO]: replace variable name in repo.

    data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr(
        user_api_key_dict, "end_user_max_budget", None
    )

    data[_metadata_variable_name]["litellm_api_version"] = version

    if general_settings is not None:

@@ -308,15 +377,6 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
            general_settings.get("global_max_parallel_requests", None)
        )

-    data[_metadata_variable_name]["user_api_key_user_id"] = user_api_key_dict.user_id
-    data[_metadata_variable_name]["user_api_key_org_id"] = user_api_key_dict.org_id
-    data[_metadata_variable_name]["user_api_key_team_id"] = getattr(
-        user_api_key_dict, "team_id", None
-    )
-    data[_metadata_variable_name]["user_api_key_team_alias"] = getattr(
-        user_api_key_dict, "team_alias", None
-    )
-
    ### KEY-LEVEL Controls
    key_metadata = user_api_key_dict.metadata
    if "cache" in key_metadata:

@@ -1412,16 +1412,20 @@ class AdapterCompletionStreamWrapper:
            raise StopAsyncIteration


-class StandardLoggingMetadata(TypedDict):
+class StandardLoggingUserAPIKeyMetadata(TypedDict):
+    user_api_key_hash: Optional[str]  # hash of the litellm virtual key used
+    user_api_key_alias: Optional[str]
+    user_api_key_org_id: Optional[str]
+    user_api_key_team_id: Optional[str]
+    user_api_key_user_id: Optional[str]
+    user_api_key_team_alias: Optional[str]
+
+
+class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
    """
    Specific metadata k,v pairs logged to integration for easier cost tracking
    """

-    user_api_key_hash: Optional[str]  # hash of the litellm virtual key used
-    user_api_key_alias: Optional[str]
-    user_api_key_team_id: Optional[str]
-    user_api_key_user_id: Optional[str]
-    user_api_key_team_alias: Optional[str]
    spend_logs_metadata: Optional[
        dict
    ]  # special param to log k,v pairs to spendlogs for a call

@@ -203,7 +203,7 @@ def test_add_headers_to_request(litellm_key_header_name):
    import json
    from litellm.proxy.litellm_pre_call_utils import (
        clean_headers,
-        get_forwardable_headers,
+        LiteLLMProxyRequestSetup,
    )

    headers = {

@@ -215,7 +215,9 @@ def test_add_headers_to_request(litellm_key_header_name):
    request._url = URL(url="/chat/completions")
    request._body = json.dumps({"model": "gpt-3.5-turbo"}).encode("utf-8")
    request_headers = clean_headers(headers, litellm_key_header_name)
-    forwarded_headers = get_forwardable_headers(request_headers)
+    forwarded_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(
+        request_headers
+    )
    assert forwarded_headers == {"X-Custom-Header": "Custom-Value"}

@@ -371,12 +371,12 @@ def test_is_request_body_safe_model_enabled(


def test_reading_openai_org_id_from_headers():
-    from litellm.proxy.litellm_pre_call_utils import get_openai_org_id_from_headers
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

    headers = {
        "OpenAI-Organization": "test_org_id",
    }
-    org_id = get_openai_org_id_from_headers(headers)
+    org_id = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(headers)
    assert org_id == "test_org_id"

@@ -399,11 +399,44 @@ def test_reading_openai_org_id_from_headers():
)
def test_add_litellm_data_for_backend_llm_call(headers, expected_data):
    import json
-    from litellm.proxy.litellm_pre_call_utils import (
-        add_litellm_data_for_backend_llm_call,
-    )
-
-    data = add_litellm_data_for_backend_llm_call(headers)
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+    from litellm.proxy._types import UserAPIKeyAuth
+
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+    )
+
+    data = LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
+        headers=headers,
+        user_api_key_dict=user_api_key_dict,
+        general_settings=None,
+    )
+
+    assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)
+
+
+def test_foward_litellm_user_info_to_backend_llm_call():
+    import json
+
+    litellm.add_user_information_to_llm_headers = True
+
+    from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+    from litellm.proxy._types import UserAPIKeyAuth
+
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test_api_key", user_id="test_user_id", org_id="test_org_id"
+    )
+
+    data = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
+        headers={},
+        user_api_key_dict=user_api_key_dict,
+    )
+
+    expected_data = {
+        "x-litellm-user_api_key_user_id": "test_user_id",
+        "x-litellm-user_api_key_org_id": "test_org_id",
+        "x-litellm-user_api_key_hash": "test_api_key",
+    }
+
    assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True)