feat(litellm_pre_call_utils.py): support 'add_user_information_to_llm_headers' (#6390)

* feat(litellm_pre_call_utils.py): support 'add_user_information_to_llm_headers' param; enables passing user info to the backend LLM (user request for a custom vllm server)
* fix(litellm_logging.py): fix linting error
parent 4e310051c7
commit 9fccf829b1

6 changed files with 221 additions and 73 deletions
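The hunks below are the changes to `litellm_pre_call_utils.py`. For context, here is a minimal sketch of how the new setting is meant to be used; the flag name and the `x-litellm-*` header format come from the diff itself, while the key values are hypothetical:

```python
# Minimal sketch (hypothetical values). With this flag enabled, the proxy
# builds "x-litellm-<field>" headers from the key's sanitized user metadata
# and forwards them to the backend LLM (e.g. a custom vllm server).
import litellm

litellm.add_user_information_to_llm_headers = True

# A key whose UserAPIKeyAuth carries user_id="user-123" and team_id="team-1"
# would then add, alongside any client-sent "x-" headers:
#   x-litellm-user_api_key_user_id: user-123
#   x-litellm-user_api_key_team_id: team-1
```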
```diff
@@ -16,7 +16,10 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
 )
 from litellm.proxy.auth.auth_utils import get_request_route
-from litellm.types.utils import SupportedCacheControls
+from litellm.types.utils import (
+    StandardLoggingUserAPIKeyMetadata,
+    SupportedCacheControls,
+)
 
 if TYPE_CHECKING:
     from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
```
```diff
@@ -159,56 +162,107 @@ def clean_headers(
     return clean_headers
 
 
-def get_forwardable_headers(
-    headers: Union[Headers, dict],
-):
-    """
-    Get the headers that should be forwarded to the LLM Provider.
-
-    Looks for any `x-` headers and sends them to the LLM Provider.
-    """
-    forwarded_headers = {}
-    for header, value in headers.items():
-        if header.lower().startswith("x-") and not header.lower().startswith(
-            "x-stainless"
-        ):  # causes openai sdk to fail
-            forwarded_headers[header] = value
-
-    return forwarded_headers
-
-
-def get_openai_org_id_from_headers(
-    headers: dict, general_settings: Optional[Dict] = None
-) -> Optional[str]:
-    """
-    Get the OpenAI Org ID from the headers.
-    """
-    if (
-        general_settings is not None
-        and general_settings.get("forward_openai_org_id") is not True
-    ):
-        return None
-    for header, value in headers.items():
-        if header.lower() == "openai-organization":
-            return value
-    return None
-
-
-def add_litellm_data_for_backend_llm_call(
-    headers: dict, general_settings: Optional[Dict[str, Any]] = None
-) -> LitellmDataForBackendLLMCall:
-    """
-    - Adds forwardable headers
-    - Adds org id
-    """
-    data = LitellmDataForBackendLLMCall()
-    _headers = get_forwardable_headers(headers)
-    if _headers != {}:
-        data["headers"] = _headers
-    _organization = get_openai_org_id_from_headers(headers, general_settings)
-    if _organization is not None:
-        data["organization"] = _organization
-    return data
+class LiteLLMProxyRequestSetup:
+    @staticmethod
+    def _get_forwardable_headers(
+        headers: Union[Headers, dict],
+    ):
+        """
+        Get the headers that should be forwarded to the LLM Provider.
+
+        Looks for any `x-` headers and sends them to the LLM Provider.
+        """
+        forwarded_headers = {}
+        for header, value in headers.items():
+            if header.lower().startswith("x-") and not header.lower().startswith(
+                "x-stainless"
+            ):  # causes openai sdk to fail
+                forwarded_headers[header] = value
+
+        return forwarded_headers
+
+    @staticmethod
+    def get_openai_org_id_from_headers(
+        headers: dict, general_settings: Optional[Dict] = None
+    ) -> Optional[str]:
+        """
+        Get the OpenAI Org ID from the headers.
+        """
+        if (
+            general_settings is not None
+            and general_settings.get("forward_openai_org_id") is not True
+        ):
+            return None
+        for header, value in headers.items():
+            if header.lower() == "openai-organization":
+                return value
+        return None
+
+    @staticmethod
+    def add_headers_to_llm_call(
+        headers: dict, user_api_key_dict: UserAPIKeyAuth
+    ) -> dict:
+        """
+        Add headers to the LLM call
+
+        - Checks request headers for forwardable headers
+        - Checks if user information should be added to the headers
+        """
+        from litellm.litellm_core_utils.litellm_logging import (
+            get_standard_logging_metadata,
+        )
+
+        returned_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(headers)
+
+        if litellm.add_user_information_to_llm_headers is True:
+            litellm_logging_metadata_headers = (
+                LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
+                    user_api_key_dict=user_api_key_dict
+                )
+            )
+            for k, v in litellm_logging_metadata_headers.items():
+                if v is not None:
+                    returned_headers["x-litellm-{}".format(k)] = v
+
+        return returned_headers
+
+    @staticmethod
+    def add_litellm_data_for_backend_llm_call(
+        *,
+        headers: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        general_settings: Optional[Dict[str, Any]] = None,
+    ) -> LitellmDataForBackendLLMCall:
+        """
+        - Adds forwardable headers
+        - Adds org id
+        """
+        data = LitellmDataForBackendLLMCall()
+        _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
+            headers, user_api_key_dict
+        )
+        if _headers != {}:
+            data["headers"] = _headers
+        _organization = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(
+            headers, general_settings
+        )
+        if _organization is not None:
+            data["organization"] = _organization
+        return data
+
+    @staticmethod
+    def get_sanitized_user_information_from_key(
+        user_api_key_dict: UserAPIKeyAuth,
+    ) -> StandardLoggingUserAPIKeyMetadata:
+        user_api_key_logged_metadata = StandardLoggingUserAPIKeyMetadata(
+            user_api_key_hash=user_api_key_dict.api_key,  # just the hashed token
+            user_api_key_alias=user_api_key_dict.key_alias,
+            user_api_key_team_id=user_api_key_dict.team_id,
+            user_api_key_user_id=user_api_key_dict.user_id,
+            user_api_key_org_id=user_api_key_dict.org_id,
+            user_api_key_team_alias=user_api_key_dict.team_alias,
+        )
+        return user_api_key_logged_metadata
 
 
 async def add_litellm_data_to_request(  # noqa: PLR0915
```
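A rough usage sketch of the refactored entry point, under stated assumptions: the module path is inferred from the file name in the commit title, and the header and key values are made up for illustration:

```python
# Illustration only; values are made up. Module path assumed from the commit
# title (litellm_pre_call_utils.py under litellm/proxy/).
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

litellm.add_user_information_to_llm_headers = True

data = LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
    headers={"x-custom-tag": "abc", "x-stainless-lang": "python"},
    user_api_key_dict=UserAPIKeyAuth(user_id="user-123"),
    general_settings=None,
)
# Expected shape, per the code above:
# data["headers"] == {
#     "x-custom-tag": "abc",                         # forwarded "x-" header
#     "x-litellm-user_api_key_user_id": "user-123",  # sanitized user info
# }
# "x-stainless-*" headers are dropped (they break the OpenAI SDK), and no
# "organization" is set since forward_openai_org_id isn't enabled here.
```

Note the new keyword-only signature (the bare `*`), which forces call sites, like the one updated in the next hunk, to pass `user_api_key_dict` explicitly.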
```diff
@@ -246,7 +300,13 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
         ),
     )
 
-    data.update(add_litellm_data_for_backend_llm_call(_headers, general_settings))
+    data.update(
+        LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
+            headers=_headers,
+            user_api_key_dict=user_api_key_dict,
+            general_settings=general_settings,
+        )
+    )
 
     # Include original request and headers in the data
     data["proxy_server_request"] = {
```
```diff
@@ -294,13 +354,22 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
         data["metadata"]
     )
 
-    data[_metadata_variable_name]["user_api_key"] = user_api_key_dict.api_key
-    data[_metadata_variable_name]["user_api_key_alias"] = getattr(
-        user_api_key_dict, "key_alias", None
+    user_api_key_logged_metadata = (
+        LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
+            user_api_key_dict=user_api_key_dict
+        )
     )
+    data[_metadata_variable_name].update(user_api_key_logged_metadata)
+    data[_metadata_variable_name][
+        "user_api_key"
+    ] = (
+        user_api_key_dict.api_key
+    )  # this is just the hashed token. [TODO]: replace variable name in repo.
 
     data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr(
         user_api_key_dict, "end_user_max_budget", None
     )
 
     data[_metadata_variable_name]["litellm_api_version"] = version
 
     if general_settings is not None:
```
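The single `update()` above is why the per-field metadata assignments are deleted in the next hunk; a small self-contained illustration (made-up values):

```python
# Made-up values; shows that update() carries the same keys the removed
# per-field assignments used to set.
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup

meta: dict = {}
meta.update(
    LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
        user_api_key_dict=UserAPIKeyAuth(user_id="u1", team_id="t1")
    )
)
# meta now contains user_api_key_user_id, user_api_key_org_id,
# user_api_key_team_id, user_api_key_team_alias (None when unset), etc.,
# so the individual writes removed below became redundant.
```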
```diff
@@ -308,15 +377,6 @@ async def add_litellm_data_to_request(  # noqa: PLR0915
         general_settings.get("global_max_parallel_requests", None)
     )
 
-    data[_metadata_variable_name]["user_api_key_user_id"] = user_api_key_dict.user_id
-    data[_metadata_variable_name]["user_api_key_org_id"] = user_api_key_dict.org_id
-    data[_metadata_variable_name]["user_api_key_team_id"] = getattr(
-        user_api_key_dict, "team_id", None
-    )
-    data[_metadata_variable_name]["user_api_key_team_alias"] = getattr(
-        user_api_key_dict, "team_alias", None
-    )
-
     ### KEY-LEVEL Controls
     key_metadata = user_api_key_dict.metadata
     if "cache" in key_metadata:
```