mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(key_management_endpoints.py): override metadata field value on up… (#7008)
* fix(key_management_endpoints.py): override metadata field value on update allow user to override tags * feat(__init__.py): expose new disable_end_user_cost_tracking_prometheus_only metric allow disabling end user cost tracking on prometheus - fixes cardinality issue * fix(litellm_pre_call_utils.py): add key/team level enforced params Fixes https://github.com/BerriAI/litellm/issues/6652 * fix(key_management_endpoints.py): allow user to pass in `enforced_params` as a top level param on /key/generate and /key/update * docs(enterprise.md): add docs on enforcing required params for llm requests * Add support of Galadriel API (#7005) * fix(router.py): robust retry after handling set retry after time to 0 if >0 healthy deployments. handle base case = 1 deployment * test(test_router.py): fix test * feat(bedrock/): add support for 'nova' models also adds explicit 'converse/' route for simpler routing * fix: fix 'supports_pdf_input' return if model supports pdf input on get_model_info * feat(converse_transformation.py): support bedrock pdf input * docs(document_understanding.md): add document understanding to docs * fix(litellm_pre_call_utils.py): fix linting error * fix(init.py): fix passing of bedrock converse models * feat(bedrock/converse): support 'response_format={"type": "json_object"}' * fix(converse_handler.py): fix linting error * fix(base_llm_unit_tests.py): fix test * fix: fix test * test: fix test * test: fix test * test: remove duplicate test --------- Co-authored-by: h4n0 <4738254+h4n0@users.noreply.github.com>
This commit is contained in:
parent
d558b643be
commit
6bb934c0ac
37 changed files with 1297 additions and 503 deletions
|
@ -587,6 +587,16 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
f"[PROXY]returned data from litellm_pre_call_utils: {data}"
|
||||
)
|
||||
|
||||
## ENFORCED PARAMS CHECK
|
||||
# loop through each enforced param
|
||||
# example enforced_params ['user', 'metadata', 'metadata.generation_name']
|
||||
_enforced_params_check(
|
||||
request_body=data,
|
||||
general_settings=general_settings,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
premium_user=premium_user,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
await service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.PROXY_PRE_CALL,
|
||||
|
@ -599,6 +609,64 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
|||
return data
|
||||
|
||||
|
||||
def _get_enforced_params(
|
||||
general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth
|
||||
) -> Optional[list]:
|
||||
enforced_params: Optional[list] = None
|
||||
if general_settings is not None:
|
||||
enforced_params = general_settings.get("enforced_params")
|
||||
if "service_account_settings" in general_settings:
|
||||
service_account_settings = general_settings["service_account_settings"]
|
||||
if "enforced_params" in service_account_settings:
|
||||
if enforced_params is None:
|
||||
enforced_params = []
|
||||
enforced_params.extend(service_account_settings["enforced_params"])
|
||||
if user_api_key_dict.metadata.get("enforced_params", None) is not None:
|
||||
if enforced_params is None:
|
||||
enforced_params = []
|
||||
enforced_params.extend(user_api_key_dict.metadata["enforced_params"])
|
||||
return enforced_params
|
||||
|
||||
|
||||
def _enforced_params_check(
|
||||
request_body: dict,
|
||||
general_settings: Optional[dict],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
premium_user: bool,
|
||||
) -> bool:
|
||||
"""
|
||||
If enforced params are set, check if the request body contains the enforced params.
|
||||
"""
|
||||
enforced_params: Optional[list] = _get_enforced_params(
|
||||
general_settings=general_settings, user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
if enforced_params is None:
|
||||
return True
|
||||
if enforced_params is not None and premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
|
||||
for enforced_param in enforced_params:
|
||||
_enforced_params = enforced_param.split(".")
|
||||
if len(_enforced_params) == 1:
|
||||
if _enforced_params[0] not in request_body:
|
||||
raise ValueError(
|
||||
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
|
||||
)
|
||||
elif len(_enforced_params) == 2:
|
||||
# this is a scenario where user requires request['metadata']['generation_name'] to exist
|
||||
if _enforced_params[0] not in request_body:
|
||||
raise ValueError(
|
||||
f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param"
|
||||
)
|
||||
if _enforced_params[1] not in request_body[_enforced_params[0]]:
|
||||
raise ValueError(
|
||||
f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def move_guardrails_to_metadata(
|
||||
data: dict,
|
||||
_metadata_variable_name: str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue