mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
refactor location of proxy
This commit is contained in:
parent
baa5564f95
commit
ce58c53ff1
413 changed files with 2087 additions and 2088 deletions
171
litellm_proxy/guardrails/init_guardrails.py
Normal file
171
litellm_proxy/guardrails/init_guardrails.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
import importlib
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm_proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||
|
||||
# v2 implementation
|
||||
from litellm.types.guardrails import (
|
||||
Guardrail,
|
||||
GuardrailItem,
|
||||
GuardrailItemSpec,
|
||||
LakeraCategoryThresholds,
|
||||
LitellmParams,
|
||||
)
|
||||
|
||||
from .guardrail_registry import guardrail_registry
|
||||
|
||||
all_guardrails: List[GuardrailItem] = []
|
||||
|
||||
|
||||
def initialize_guardrails(
|
||||
guardrails_config: List[Dict[str, GuardrailItemSpec]],
|
||||
premium_user: bool,
|
||||
config_file_path: str,
|
||||
litellm_settings: dict,
|
||||
) -> Dict[str, GuardrailItem]:
|
||||
try:
|
||||
verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}")
|
||||
global all_guardrails
|
||||
for item in guardrails_config:
|
||||
"""
|
||||
one item looks like this:
|
||||
|
||||
{'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['user']}}
|
||||
"""
|
||||
for k, v in item.items():
|
||||
guardrail_item = GuardrailItem(**v, guardrail_name=k)
|
||||
all_guardrails.append(guardrail_item)
|
||||
litellm.guardrail_name_config_map[k] = guardrail_item
|
||||
|
||||
# set appropriate callbacks if they are default on
|
||||
default_on_callbacks = set()
|
||||
callback_specific_params = {}
|
||||
for guardrail in all_guardrails:
|
||||
verbose_proxy_logger.debug(guardrail.guardrail_name)
|
||||
verbose_proxy_logger.debug(guardrail.default_on)
|
||||
|
||||
callback_specific_params.update(guardrail.callback_args)
|
||||
|
||||
if guardrail.default_on is True:
|
||||
# add these to litellm callbacks if they don't exist
|
||||
for callback in guardrail.callbacks:
|
||||
if callback not in litellm.callbacks:
|
||||
default_on_callbacks.add(callback)
|
||||
|
||||
if guardrail.logging_only is True:
|
||||
if callback == "presidio":
|
||||
callback_specific_params["presidio"] = {"logging_only": True} # type: ignore
|
||||
|
||||
default_on_callbacks_list = list(default_on_callbacks)
|
||||
if len(default_on_callbacks_list) > 0:
|
||||
initialize_callbacks_on_proxy(
|
||||
value=default_on_callbacks_list,
|
||||
premium_user=premium_user,
|
||||
config_file_path=config_file_path,
|
||||
litellm_settings=litellm_settings,
|
||||
callback_specific_params=callback_specific_params,
|
||||
)
|
||||
|
||||
return litellm.guardrail_name_config_map
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"error initializing guardrails {}".format(str(e))
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
"""
|
||||
Map guardrail_name: <pre_call>, <post_call>, during_call
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def init_guardrails_v2(
|
||||
all_guardrails: List[Dict],
|
||||
config_file_path: Optional[str] = None,
|
||||
):
|
||||
guardrail_list = []
|
||||
|
||||
for guardrail in all_guardrails:
|
||||
litellm_params_data = guardrail["litellm_params"]
|
||||
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
|
||||
|
||||
_litellm_params_kwargs = {
|
||||
k: litellm_params_data.get(k) for k in LitellmParams.__annotations__.keys()
|
||||
}
|
||||
|
||||
litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
|
||||
|
||||
if (
|
||||
"category_thresholds" in litellm_params_data
|
||||
and litellm_params_data["category_thresholds"]
|
||||
):
|
||||
lakera_category_thresholds = LakeraCategoryThresholds(
|
||||
**litellm_params_data["category_thresholds"]
|
||||
)
|
||||
litellm_params["category_thresholds"] = lakera_category_thresholds
|
||||
|
||||
if litellm_params["api_key"] and litellm_params["api_key"].startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
|
||||
|
||||
if litellm_params["api_base"] and litellm_params["api_base"].startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
|
||||
|
||||
guardrail_type = litellm_params["guardrail"]
|
||||
|
||||
initializer = guardrail_registry.get(guardrail_type)
|
||||
|
||||
if initializer:
|
||||
initializer(litellm_params, guardrail)
|
||||
elif isinstance(guardrail_type, str) and "." in guardrail_type:
|
||||
if not config_file_path:
|
||||
raise Exception(
|
||||
"GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2"
|
||||
)
|
||||
|
||||
_file_name, _class_name = guardrail_type.split(".")
|
||||
verbose_proxy_logger.debug(
|
||||
"Initializing custom guardrail: %s, file_name: %s, class_name: %s",
|
||||
guardrail_type,
|
||||
_file_name,
|
||||
_class_name,
|
||||
)
|
||||
|
||||
directory = os.path.dirname(config_file_path)
|
||||
module_file_path = os.path.join(directory, _file_name) + ".py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location(_class_name, module_file_path) # type: ignore
|
||||
if not spec:
|
||||
raise ImportError(
|
||||
f"Could not find a module specification for {module_file_path}"
|
||||
)
|
||||
|
||||
module = importlib.util.module_from_spec(spec) # type: ignore
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
_guardrail_class = getattr(module, _class_name)
|
||||
|
||||
_guardrail_callback = _guardrail_class(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
event_hook=litellm_params["mode"],
|
||||
default_on=litellm_params["default_on"],
|
||||
)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported guardrail: {guardrail_type}")
|
||||
|
||||
parsed_guardrail = Guardrail(
|
||||
guardrail_name=guardrail["guardrail_name"],
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
guardrail_list.append(parsed_guardrail)
|
||||
|
||||
verbose_proxy_logger.info(f"\nGuardrail List:{guardrail_list}\n")
|
Loading…
Add table
Add a link
Reference in a new issue