From e2d753969016d13991f258a183c45face227e916 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 21 Aug 2024 15:01:52 -0700 Subject: [PATCH 01/22] feat(caching.py): redis cluster support Closes https://github.com/BerriAI/litellm/issues/4358 --- litellm/_redis.py | 58 +++++++++++++++++++++++++-- litellm/caching.py | 14 +++++-- litellm/proxy/_new_secret_config.yaml | 7 ++++ litellm/proxy/proxy_server.py | 2 +- litellm/tests/test_caching.py | 32 +++++++++++++++ 5 files changed, 106 insertions(+), 7 deletions(-) diff --git a/litellm/_redis.py b/litellm/_redis.py index d72016dcd..23f82ed1a 100644 --- a/litellm/_redis.py +++ b/litellm/_redis.py @@ -7,13 +7,17 @@ # # Thank you users! We ❤️ you! - Krrish & Ishaan +import inspect + # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation import os -import inspect -import redis, litellm # type: ignore -import redis.asyncio as async_redis # type: ignore from typing import List, Optional +import redis # type: ignore +import redis.asyncio as async_redis # type: ignore + +import litellm + def _get_redis_kwargs(): arg_spec = inspect.getfullargspec(redis.Redis) @@ -51,6 +55,19 @@ def _get_redis_url_kwargs(client=None): return available_args +def _get_redis_cluster_kwargs(client=None): + if client is None: + client = redis.Redis.from_url + arg_spec = inspect.getfullargspec(redis.RedisCluster) + + # Only allow primitive arguments + exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"} + + available_args = [x for x in arg_spec.args if x not in exclude_args] + + return available_args + + def _get_redis_env_kwarg_mapping(): PREFIX = "REDIS_" @@ -124,6 +141,22 @@ def get_redis_client(**env_overrides): url_kwargs[arg] = redis_kwargs[arg] return redis.Redis.from_url(**url_kwargs) + + if "startup_nodes" in redis_kwargs: + from redis.cluster import ClusterNode + + args = _get_redis_cluster_kwargs() + cluster_kwargs = {} + for arg in redis_kwargs: + if arg in args: + cluster_kwargs[arg] = redis_kwargs[arg] + + new_startup_nodes: List[ClusterNode] = [] + + for item in redis_kwargs["startup_nodes"]: + new_startup_nodes.append(ClusterNode(**item)) + redis_kwargs.pop("startup_nodes") + return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) return redis.Redis(**redis_kwargs) @@ -143,6 +176,24 @@ def get_redis_async_client(**env_overrides): ) return async_redis.Redis.from_url(**url_kwargs) + if "startup_nodes" in redis_kwargs: + from redis.cluster import ClusterNode + + args = _get_redis_cluster_kwargs() + cluster_kwargs = {} + for arg in redis_kwargs: + if arg in args: + cluster_kwargs[arg] = redis_kwargs[arg] + + new_startup_nodes: List[ClusterNode] = [] + + for item in redis_kwargs["startup_nodes"]: + new_startup_nodes.append(ClusterNode(**item)) + redis_kwargs.pop("startup_nodes") + return async_redis.RedisCluster( + startup_nodes=new_startup_nodes, **cluster_kwargs + ) + return async_redis.Redis( socket_timeout=5, **redis_kwargs, @@ -160,4 +211,5 @@ def get_redis_connection_pool(**env_overrides): connection_class = async_redis.SSLConnection redis_kwargs.pop("ssl", None) redis_kwargs["connection_class"] = connection_class + redis_kwargs.pop("startup_nodes", None) return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs) diff --git a/litellm/caching.py b/litellm/caching.py index 1c7216029..1b19fdf3e 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -203,6 +203,7 @@ class RedisCache(BaseCache): password=None, redis_flush_size=100, namespace: 
Optional[str] = None, + startup_nodes: Optional[List] = None, # for redis-cluster **kwargs, ): import redis @@ -218,7 +219,8 @@ class RedisCache(BaseCache): redis_kwargs["port"] = port if password is not None: redis_kwargs["password"] = password - + if startup_nodes is not None: + redis_kwargs["startup_nodes"] = startup_nodes ### HEALTH MONITORING OBJECT ### if kwargs.get("service_logger_obj", None) is not None and isinstance( kwargs["service_logger_obj"], ServiceLogging @@ -246,7 +248,7 @@ class RedisCache(BaseCache): ### ASYNC HEALTH PING ### try: # asyncio.get_running_loop().create_task(self.ping()) - result = asyncio.get_running_loop().create_task(self.ping()) + _ = asyncio.get_running_loop().create_task(self.ping()) except Exception as e: if "no running event loop" in str(e): verbose_logger.debug( @@ -2123,6 +2125,7 @@ class Cache: redis_semantic_cache_use_async=False, redis_semantic_cache_embedding_model="text-embedding-ada-002", redis_flush_size=None, + redis_startup_nodes: Optional[List] = None, disk_cache_dir=None, qdrant_api_base: Optional[str] = None, qdrant_api_key: Optional[str] = None, @@ -2155,7 +2158,12 @@ class Cache: """ if type == "redis": self.cache: BaseCache = RedisCache( - host, port, password, redis_flush_size, **kwargs + host, + port, + password, + redis_flush_size, + startup_nodes=redis_startup_nodes, + **kwargs, ) elif type == "redis-semantic": self.cache = RedisSemanticCache( diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 96a0242a8..10d608ec8 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -2,3 +2,10 @@ model_list: - model_name: "*" litellm_params: model: "*" + + +litellm_settings: + cache: True + cache_params: + type: redis + redis_startup_nodes: [{"host": "127.0.0.1", "port": "7001"}] diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index a9d0325d8..8986b587b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1583,7 +1583,7 @@ class ProxyConfig: verbose_proxy_logger.debug( # noqa f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}" ) - elif key == "cache" and value == False: + elif key == "cache" and value is False: pass elif key == "guardrails": if premium_user is not True: diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 64196e5c5..5da883f4a 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -804,6 +804,38 @@ def test_redis_cache_completion_stream(): # test_redis_cache_completion_stream() +# @pytest.mark.skip(reason="Local test. 
Requires running redis cluster locally.") +@pytest.mark.asyncio +async def test_redis_cache_cluster_init_unit_test(): + try: + from redis.asyncio import RedisCluster as AsyncRedisCluster + from redis.cluster import RedisCluster + + from litellm.caching import RedisCache + + litellm.set_verbose = True + + # List of startup nodes + startup_nodes = [ + {"host": "127.0.0.1", "port": "7001"}, + ] + + resp = RedisCache(startup_nodes=startup_nodes) + + assert isinstance(resp.redis_client, RedisCluster) + assert isinstance(resp.init_async_client(), AsyncRedisCluster) + + resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes) + + assert isinstance(resp.cache, RedisCache) + assert isinstance(resp.cache.redis_client, RedisCluster) + assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster) + + except Exception as e: + print(f"{str(e)}\n\n{traceback.format_exc()}") + raise e + + @pytest.mark.asyncio async def test_redis_cache_acompletion_stream(): try: From f24075bcafd9737101d09c89c79c0182ac8d0d68 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 21 Aug 2024 15:05:18 -0700 Subject: [PATCH 02/22] test(test_caching.py): skip local test --- litellm/tests/test_caching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py index 5da883f4a..e474dff2e 100644 --- a/litellm/tests/test_caching.py +++ b/litellm/tests/test_caching.py @@ -804,7 +804,7 @@ def test_redis_cache_completion_stream(): # test_redis_cache_completion_stream() -# @pytest.mark.skip(reason="Local test. Requires running redis cluster locally.") +@pytest.mark.skip(reason="Local test. Requires running redis cluster locally.") @pytest.mark.asyncio async def test_redis_cache_cluster_init_unit_test(): try: From 008fa494a77203016d5b71e85cf420b4309db39e Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 21 Aug 2024 15:35:10 -0700 Subject: [PATCH 03/22] fix(router.py): fix linting error --- litellm/router.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index e261c1743..7a938f5c4 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -277,7 +277,8 @@ class Router: "local" # default to an in-memory cache ) redis_cache = None - cache_config = {} + cache_config: Dict[str, Any] = {} + self.client_ttl = client_ttl if redis_url is not None or ( redis_host is not None From a3537afbdf033390c80efd732b1a1b1c7a4130a8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:44:08 +0000 Subject: [PATCH 04/22] build(deps): bump hono from 4.2.7 to 4.5.8 in /litellm-js/spend-logs Bumps [hono](https://github.com/honojs/hono) from 4.2.7 to 4.5.8. - [Release notes](https://github.com/honojs/hono/releases) - [Commits](https://github.com/honojs/hono/compare/v4.2.7...v4.5.8) --- updated-dependencies: - dependency-name: hono dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] --- litellm-js/spend-logs/package-lock.json | 8 ++++---- litellm-js/spend-logs/package.json | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/litellm-js/spend-logs/package-lock.json b/litellm-js/spend-logs/package-lock.json index cb4b599d3..5d8b85ad5 100644 --- a/litellm-js/spend-logs/package-lock.json +++ b/litellm-js/spend-logs/package-lock.json @@ -6,7 +6,7 @@ "": { "dependencies": { "@hono/node-server": "^1.10.1", - "hono": "^4.2.7" + "hono": "^4.5.8" }, "devDependencies": { "@types/node": "^20.11.17", @@ -463,9 +463,9 @@ } }, "node_modules/hono": { - "version": "4.2.7", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz", - "integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==", + "version": "4.5.8", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz", + "integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==", "engines": { "node": ">=16.0.0" } diff --git a/litellm-js/spend-logs/package.json b/litellm-js/spend-logs/package.json index d9543220b..359935c25 100644 --- a/litellm-js/spend-logs/package.json +++ b/litellm-js/spend-logs/package.json @@ -4,7 +4,7 @@ }, "dependencies": { "@hono/node-server": "^1.10.1", - "hono": "^4.2.7" + "hono": "^4.5.8" }, "devDependencies": { "@types/node": "^20.11.17", From 11bfc1dca7359a8c0f921bc06daf0ce2910c3bcf Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Aug 2024 10:16:43 -0700 Subject: [PATCH 05/22] fix(cohere_chat.py): support passing 'extra_headers' Fixes https://github.com/BerriAI/litellm/issues/4709 --- litellm/llms/cohere.py | 20 ++++++++++++-------- litellm/llms/cohere_chat.py | 17 ++++++++++------- litellm/main.py | 26 +++++++++++++++++++++++++- litellm/tests/test_completion.py | 1 + litellm/utils.py | 2 ++ 5 files changed, 50 insertions(+), 16 deletions(-) diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py index 3873027b2..8bd1051e8 100644 --- a/litellm/llms/cohere.py +++ b/litellm/llms/cohere.py @@ -124,12 +124,14 @@ class CohereConfig: } -def validate_environment(api_key): - headers = { - "Request-Source": "unspecified:litellm", - "accept": "application/json", - "content-type": "application/json", - } +def validate_environment(api_key, headers: dict): + headers.update( + { + "Request-Source": "unspecified:litellm", + "accept": "application/json", + "content-type": "application/json", + } + ) if api_key: headers["Authorization"] = f"Bearer {api_key}" return headers @@ -144,11 +146,12 @@ def completion( encoding, api_key, logging_obj, + headers: dict, optional_params=None, litellm_params=None, logger_fn=None, ): - headers = validate_environment(api_key) + headers = validate_environment(api_key, headers=headers) completion_url = api_base model = model prompt = " ".join(message["content"] for message in messages) @@ -338,13 +341,14 @@ def embedding( model_response: litellm.EmbeddingResponse, logging_obj: LiteLLMLoggingObj, optional_params: dict, + headers: dict, encoding: Any, api_key: Optional[str] = None, aembedding: Optional[bool] = None, timeout: Union[float, httpx.Timeout] = httpx.Timeout(None), client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ): - headers = validate_environment(api_key) + headers = validate_environment(api_key, headers=headers) embed_url = "https://api.cohere.ai/v1/embed" model = model data = {"model": model, "texts": input, **optional_params} diff --git a/litellm/llms/cohere_chat.py 
b/litellm/llms/cohere_chat.py index a0a9a9874..f13e74614 100644 --- a/litellm/llms/cohere_chat.py +++ b/litellm/llms/cohere_chat.py @@ -116,12 +116,14 @@ class CohereChatConfig: } -def validate_environment(api_key): - headers = { - "Request-Source": "unspecified:litellm", - "accept": "application/json", - "content-type": "application/json", - } +def validate_environment(api_key, headers: dict): + headers.update( + { + "Request-Source": "unspecified:litellm", + "accept": "application/json", + "content-type": "application/json", + } + ) if api_key: headers["Authorization"] = f"Bearer {api_key}" return headers @@ -203,13 +205,14 @@ def completion( model_response: ModelResponse, print_verbose: Callable, optional_params: dict, + headers: dict, encoding, api_key, logging_obj, litellm_params=None, logger_fn=None, ): - headers = validate_environment(api_key) + headers = validate_environment(api_key, headers=headers) completion_url = api_base model = model most_recent_message, chat_history = cohere_messages_pt_v2( diff --git a/litellm/main.py b/litellm/main.py index 80a9a94a3..1beca0113 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1634,6 +1634,13 @@ def completion( or "https://api.cohere.ai/v1/generate" ) + headers = headers or litellm.headers or {} + if headers is None: + headers = {} + + if extra_headers is not None: + headers.update(extra_headers) + model_response = cohere.completion( model=model, messages=messages, @@ -1644,6 +1651,7 @@ def completion( litellm_params=litellm_params, logger_fn=logger_fn, encoding=encoding, + headers=headers, api_key=cohere_key, logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements ) @@ -1674,6 +1682,13 @@ def completion( or "https://api.cohere.ai/v1/chat" ) + headers = headers or litellm.headers or {} + if headers is None: + headers = {} + + if extra_headers is not None: + headers.update(extra_headers) + model_response = cohere_chat.completion( model=model, messages=messages, @@ -1682,6 +1697,7 @@ def completion( print_verbose=print_verbose, optional_params=optional_params, litellm_params=litellm_params, + headers=headers, logger_fn=logger_fn, encoding=encoding, api_key=cohere_key, @@ -3158,6 +3174,7 @@ def embedding( encoding_format = kwargs.get("encoding_format", None) proxy_server_request = kwargs.get("proxy_server_request", None) aembedding = kwargs.get("aembedding", None) + extra_headers = kwargs.get("extra_headers", None) ### CUSTOM MODEL COST ### input_cost_per_token = kwargs.get("input_cost_per_token", None) output_cost_per_token = kwargs.get("output_cost_per_token", None) @@ -3229,6 +3246,7 @@ def embedding( "model_config", "cooldown_time", "tags", + "extra_headers", ] default_params = openai_params + litellm_params non_default_params = { @@ -3292,7 +3310,7 @@ def embedding( "cooldown_time": cooldown_time, }, ) - if azure == True or custom_llm_provider == "azure": + if azure is True or custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" @@ -3398,12 +3416,18 @@ def embedding( or get_secret("CO_API_KEY") or litellm.api_key ) + + if extra_headers is not None and isinstance(extra_headers, dict): + headers = extra_headers + else: + headers = {} response = cohere.embedding( model=model, input=input, optional_params=optional_params, encoding=encoding, api_key=cohere_key, # type: ignore + headers=headers, logging_obj=logging, model_response=EmbeddingResponse(), aembedding=aembedding, diff --git a/litellm/tests/test_completion.py 
b/litellm/tests/test_completion.py index 0941484d9..c0c3c70f9 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -3653,6 +3653,7 @@ def test_completion_cohere(): response = completion( model="command-r", messages=messages, + extra_headers={"Helicone-Property-Locale": "ko"}, ) print(response) except Exception as e: diff --git a/litellm/utils.py b/litellm/utils.py index 0e9e531e9..f3bb944a8 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4219,6 +4219,7 @@ def get_supported_openai_params( "presence_penalty", "stop", "n", + "extra_headers", ] elif custom_llm_provider == "cohere_chat": return [ @@ -4233,6 +4234,7 @@ def get_supported_openai_params( "tools", "tool_choice", "seed", + "extra_headers", ] elif custom_llm_provider == "maritalk": return [ From 70bf8bd4f44e65e29cc11fe5da8fd141cd026410 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Aug 2024 11:03:33 -0700 Subject: [PATCH 06/22] feat(factory.py): enable 'user_continue_message' for interweaving user/assistant messages when provider requires it allows bedrock to be used with autogen --- litellm/llms/bedrock_httpx.py | 16 ++++++++----- litellm/llms/prompt_templates/factory.py | 29 ++++++++++++++++++++++++ litellm/main.py | 3 ++- litellm/tests/test_bedrock_completion.py | 3 ++- litellm/types/utils.py | 1 + litellm/utils.py | 10 ++++++++ 6 files changed, 54 insertions(+), 8 deletions(-) diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index e45559752..23e7fdc3e 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [ "meta.llama3-1-8b-instruct-v1:0", "meta.llama3-1-70b-instruct-v1:0", "meta.llama3-1-405b-instruct-v1:0", + "meta.llama3-70b-instruct-v1:0", "mistral.mistral-large-2407-v1:0", ] @@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM): optional_params: dict, acompletion: bool, timeout: Optional[Union[float, httpx.Timeout]], - litellm_params=None, + litellm_params: dict, logger_fn=None, extra_headers: Optional[dict] = None, client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, @@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM): supported_tool_call_params = ["tools", "tool_choice"] supported_guardrail_params = ["guardrailConfig"] ## TRANSFORMATION ## + + bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( + messages=messages, + model=model, + llm_provider="bedrock_converse", + user_continue_message=litellm_params.pop("user_continue_message", None), + ) + # send all model-specific params in 'additional_request_params' for k, v in inference_params.items(): if ( @@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM): for key in additional_request_keys: inference_params.pop(key, None) - bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( - messages=messages, - model=model, - llm_provider="bedrock_converse", - ) bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( inference_params.pop("tools", []) ) diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index c9e691c00..2b9a7fc24 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt(): BAD_MESSAGE_ERROR_STR = "Invalid Message " +# used to interweave user messages, to ensure user/assistant alternating +DEFAULT_USER_CONTINUE_MESSAGE = { + "role": "user", + "content": "Please continue.", +} # similar to autogen. 
Only used if `litellm.modify_params=True`. + +# used to interweave assistant messages, to ensure user/assistant alternating +DEFAULT_ASSISTANT_CONTINUE_MESSAGE = { + "role": "assistant", + "content": "Please continue.", +} # similar to autogen. Only used if `litellm.modify_params=True`. + def map_system_message_pt(messages: list) -> list: """ @@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt( messages: List, model: str, llm_provider: str, + user_continue_message: Optional[dict] = None, ) -> List[BedrockMessageBlock]: """ Converts given messages from OpenAI format to Bedrock format @@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt( contents: List[BedrockMessageBlock] = [] msg_i = 0 + + # if initial message is assistant message + if messages[0].get("role") is not None and messages[0]["role"] == "assistant": + if user_continue_message is not None: + messages.insert(0, user_continue_message) + elif litellm.modify_params: + messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE) + + # if final message is assistant message + if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant": + if user_continue_message is not None: + messages.append(user_continue_message) + elif litellm.modify_params: + messages.append(DEFAULT_USER_CONTINUE_MESSAGE) + while msg_i < len(messages): user_content: List[BedrockContentBlock] = [] init_msg_i = msg_i @@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt( model=model, llm_provider=llm_provider, ) + return contents diff --git a/litellm/main.py b/litellm/main.py index 1beca0113..28054537c 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -943,6 +943,7 @@ def completion( output_cost_per_token=output_cost_per_token, cooldown_time=cooldown_time, text_completion=kwargs.get("text_completion"), + user_continue_message=kwargs.get("user_continue_message"), ) logging.update_environment_variables( model=model, @@ -2304,7 +2305,7 @@ def completion( model_response=model_response, print_verbose=print_verbose, optional_params=optional_params, - litellm_params=litellm_params, + litellm_params=litellm_params, # type: ignore logger_fn=logger_fn, encoding=encoding, logging_obj=logging, diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 4892601b1..90592b499 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model): "temperature": 0.3, "messages": [ {"role": "system", "content": system}, - {"role": "user", "content": "hey, how's it going?"}, + {"role": "assistant", "content": "hey, how's it going?"}, ], + "user_continue_message": {"role": "user", "content": "Be a good bot!"}, } response: ModelResponse = completion( model="bedrock/{}".format(model), diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 8efbe5a11..6b278efa1 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1116,6 +1116,7 @@ all_litellm_params = [ "cooldown_time", "cache_key", "max_retries", + "user_continue_message", ] diff --git a/litellm/utils.py b/litellm/utils.py index f3bb944a8..9c6f0b849 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2323,6 +2323,7 @@ def get_litellm_params( output_cost_per_second=None, cooldown_time=None, text_completion=None, + user_continue_message=None, ): litellm_params = { "acompletion": acompletion, @@ -2347,6 +2348,7 @@ def get_litellm_params( "output_cost_per_second": output_cost_per_second, "cooldown_time": cooldown_time, "text_completion": 
text_completion, + "user_continue_message": user_continue_message, } return litellm_params @@ -7123,6 +7125,14 @@ def exception_type( llm_provider="bedrock", response=original_exception.response, ) + elif "A conversation must start with a user message." in error_str: + exception_mapping_worked = True + raise BadRequestError( + message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.", + model=model, + llm_provider="bedrock", + response=original_exception.response, + ) elif ( "Unable to locate credentials" in error_str or "The security token included in the request is invalid" From 98f73b35ba9578d88e0b11ebc9efc6519dd5dc3b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Aug 2024 11:05:25 -0700 Subject: [PATCH 07/22] docs(utils.py): cleanup docstring --- litellm/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/utils.py b/litellm/utils.py index 9c6f0b849..7596de81d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7128,7 +7128,7 @@ def exception_type( elif "A conversation must start with a user message." in error_str: exception_mapping_worked = True raise BadRequestError( - message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.", + message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`", model=model, llm_provider="bedrock", response=original_exception.response, From a63c5c002066cefd6612b510afb0bca7f412ea5f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Aug 2024 11:34:52 -0700 Subject: [PATCH 08/22] docs(azure_ai.md): add azure ai jamba instruct to docs Closes https://github.com/BerriAI/litellm/issues/5333 --- docs/my-website/docs/providers/azure_ai.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/my-website/docs/providers/azure_ai.md b/docs/my-website/docs/providers/azure_ai.md index 26c965a0c..23993b52a 100644 --- a/docs/my-website/docs/providers/azure_ai.md +++ b/docs/my-website/docs/providers/azure_ai.md @@ -307,8 +307,9 @@ LiteLLM supports **ALL** azure ai models. 
Here's a few examples: | Model Name | Function Call | |--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` | -| Cohere command-r | `completion(model="azure/command-r", messages)` | -| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` | +| Cohere command-r-plus | `completion(model="azure_ai/command-r-plus", messages)` | +| Cohere command-r | `completion(model="azure_ai/command-r", messages)` | +| mistral-large-latest | `completion(model="azure_ai/mistral-large-latest", messages)` | +| AI21-Jamba-Instruct | `completion(model="azure_ai/ai21-jamba-instruct", messages)` | From 65c0626aa4c01f387b02cf0238880dfa5bd21593 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 13:29:35 -0700 Subject: [PATCH 09/22] fix init correct prometheus metrics --- litellm/integrations/prometheus.py | 35 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 321c1cc1f..1471f59b7 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -103,26 +103,27 @@ class PrometheusLogger(CustomLogger): "Remaining budget for api key", labelnames=["hashed_api_key", "api_key_alias"], ) + + ######################################## + # LiteLLM Virtual API KEY metrics + ######################################## + # Remaining MODEL RPM limit for API Key + self.litellm_remaining_api_key_requests_for_model = Gauge( + "litellm_remaining_api_key_requests_for_model", + "Remaining Requests API Key can make for model (model based rpm limit on key)", + labelnames=["hashed_api_key", "api_key_alias", "model"], + ) + + # Remaining MODEL TPM limit for API Key + self.litellm_remaining_api_key_tokens_for_model = Gauge( + "litellm_remaining_api_key_tokens_for_model", + "Remaining Tokens API Key can make for model (model based tpm limit on key)", + labelnames=["hashed_api_key", "api_key_alias", "model"], + ) + # Litellm-Enterprise Metrics if premium_user is True: - ######################################## - # LiteLLM Virtual API KEY metrics - ######################################## - # Remaining MODEL RPM limit for API Key - self.litellm_remaining_api_key_requests_for_model = Gauge( - "litellm_remaining_api_key_requests_for_model", - "Remaining Requests API Key can make for model (model based rpm limit on key)", - labelnames=["hashed_api_key", "api_key_alias", "model"], - ) - - # Remaining MODEL TPM limit for API Key - self.litellm_remaining_api_key_tokens_for_model = Gauge( - "litellm_remaining_api_key_tokens_for_model", - "Remaining Tokens API Key can make for model (model based tpm limit on key)", - labelnames=["hashed_api_key", "api_key_alias", "model"], - ) - ######################################## # LLM API Deployment Metrics / analytics ######################################## From e2cdb00a810d2c6dde0b837f8b83490dc3715602 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 13:52:03 -0700 Subject: [PATCH 10/22] track api_call_start_time --- litellm/litellm_core_utils/litellm_logging.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index d59f98558..dbf2a7d3e 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ 
b/litellm/litellm_core_utils/litellm_logging.py @@ -354,6 +354,8 @@ class Logging: str(e) ) ) + + self.model_call_details["api_call_start_time"] = datetime.datetime.now() # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made callbacks = litellm.input_callback + self.dynamic_input_callbacks for callback in callbacks: From 06a362d35fb25c4c0ab7bd239fb765a50c48392e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 13:58:10 -0700 Subject: [PATCH 11/22] track litellm_request_latency_metric --- litellm/integrations/prometheus.py | 38 ++++++++++++++++++++++++++++++ litellm/proxy/proxy_config.yaml | 4 +++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 1471f59b7..dadafa80e 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -60,6 +60,25 @@ class PrometheusLogger(CustomLogger): ], ) + # request latency metrics + self.litellm_request_latency_metric = Histogram( + "litellm_request_latency_metric", + "Total latency (seconds) for a request to LiteLLM", + labelnames=[ + "model", + "litellm_call_id", + ], + ) + + self.litellm_deployment_latency_metric = Histogram( + "litellm_deployment_latency_metric", + "Total latency (seconds) for a models LLM API call", + labelnames=[ + "model", + "litellm_call_id", + ], + ) + # Counter for spend self.litellm_spend_metric = Counter( "litellm_spend_metric", @@ -329,6 +348,25 @@ class PrometheusLogger(CustomLogger): user_api_key, user_api_key_alias, model_group ).set(remaining_tokens) + # latency metrics + total_time: timedelta = kwargs.get("end_time") - kwargs.get("start_time") + total_time_seconds = total_time.total_seconds() + api_call_total_time: timedelta = kwargs.get("end_time") - kwargs.get( + "api_call_start_time" + ) + + api_call_total_time_seconds = api_call_total_time.total_seconds() + + litellm_call_id = kwargs.get("litellm_call_id") + + self.litellm_request_latency_metric.labels(model, litellm_call_id).observe( + total_time_seconds + ) + + self.litellm_deployment_latency_metric.labels(model, litellm_call_id).observe( + api_call_total_time_seconds + ) + # set x-ratelimit headers if premium_user is True: self.set_llm_deployment_success_metrics( diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 65c7f7052..7c524eb18 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -4,7 +4,9 @@ model_list: model: openai/fake api_key: fake-key api_base: https://exampleopenaiendpoint-production.up.railway.app/ - +litellm_settings: + success_callback: ["prometheus"] + failure_callback: ["prometheus"] guardrails: - guardrail_name: "lakera-pre-guard" litellm_params: From 36b550b8db404999829985271be52377145f87f4 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 14:03:00 -0700 Subject: [PATCH 12/22] update promtheus metric names --- litellm/integrations/prometheus.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index dadafa80e..659e5b193 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -61,8 +61,8 @@ class PrometheusLogger(CustomLogger): ) # request latency metrics - self.litellm_request_latency_metric = Histogram( - "litellm_request_latency_metric", + self.litellm_request_total_latency_metric = Histogram( + "litellm_request_total_latency_metric", "Total 
latency (seconds) for a request to LiteLLM", labelnames=[ "model", @@ -70,8 +70,8 @@ class PrometheusLogger(CustomLogger): ], ) - self.litellm_deployment_latency_metric = Histogram( - "litellm_deployment_latency_metric", + self.litellm_llm_api_latency_metric = Histogram( + "litellm_llm_api_latency_metric", "Total latency (seconds) for a models LLM API call", labelnames=[ "model", @@ -359,11 +359,11 @@ class PrometheusLogger(CustomLogger): litellm_call_id = kwargs.get("litellm_call_id") - self.litellm_request_latency_metric.labels(model, litellm_call_id).observe( - total_time_seconds - ) + self.litellm_request_total_latency_metric.labels( + model, litellm_call_id + ).observe(total_time_seconds) - self.litellm_deployment_latency_metric.labels(model, litellm_call_id).observe( + self.litellm_llm_api_latency_metric.labels(model, litellm_call_id).observe( api_call_total_time_seconds ) From 57707b04b6993f9e375569fcc98caca8b9177177 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 14:06:14 -0700 Subject: [PATCH 13/22] add prom docs for Request Latency Metrics --- docs/my-website/docs/proxy/prometheus.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md index 4b913d2e8..10e6456c2 100644 --- a/docs/my-website/docs/proxy/prometheus.md +++ b/docs/my-website/docs/proxy/prometheus.md @@ -68,6 +68,15 @@ http://localhost:4000/metrics | `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` | | `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` | +### Request Latency Metrics + +| Metric Name | Description | +|----------------------|--------------------------------------| +| `litellm_request_total_latency_metric` | Total latency (seconds) for a request to LiteLLM Proxy Server - tracked for labels `litellm_call_id`, `model` | +| `litellm_llm_api_latency_metric` | latency (seconds) for just the LLM API call - tracked for labels `litellm_call_id`, `model` | + + + ### LLM API / Provider Metrics | Metric Name | Description | From 62df7c755b6311cbfa3f33c824d0ca4634ddf1d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Thu, 22 Aug 2024 23:21:40 +0200 Subject: [PATCH 14/22] add dbally project --- docs/my-website/docs/projects/dbally.md | 3 +++ docs/my-website/sidebars.js | 1 + 2 files changed, 4 insertions(+) create mode 100644 docs/my-website/docs/projects/dbally.md diff --git a/docs/my-website/docs/projects/dbally.md b/docs/my-website/docs/projects/dbally.md new file mode 100644 index 000000000..688f1ab0f --- /dev/null +++ b/docs/my-website/docs/projects/dbally.md @@ -0,0 +1,3 @@ +Efficient, consistent and secure library for querying structured data with natural language. Query any database with over 100 LLMs ❤️ 🚅. 
+ +🔗 [GitHub](https://github.com/deepsense-ai/db-ally) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 54df1f3e3..368609d4b 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -261,6 +261,7 @@ const sidebars = { items: [ "projects/Docq.AI", "projects/OpenInterpreter", + "projects/dbally", "projects/FastREPL", "projects/PROMPTMETHEUS", "projects/Codium PR Agent", From 14a6ce367d2d23057d8131dbecba127b3c345b69 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 15:40:58 -0700 Subject: [PATCH 15/22] add types for BedrockMessage --- litellm/types/guardrails.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 66c2a535a..13992beec 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, List, Optional, TypedDict +from typing import Dict, List, Literal, Optional, TypedDict from pydantic import BaseModel, ConfigDict from typing_extensions import Required, TypedDict @@ -76,8 +76,14 @@ class LitellmParams(TypedDict, total=False): mode: str api_key: str api_base: Optional[str] + + # Lakera specific params category_thresholds: Optional[LakeraCategoryThresholds] + # Bedrock specific params + guardrailIdentifier: Optional[str] + guardrailVersion: Optional[str] + class Guardrail(TypedDict): guardrail_name: str @@ -92,3 +98,16 @@ class GuardrailEventHooks(str, Enum): pre_call = "pre_call" post_call = "post_call" during_call = "during_call" + + +class BedrockTextContent(TypedDict): + text: str + + +class BedrockContentItem(TypedDict): + text: BedrockTextContent + + +class BedrockMessage(TypedDict): + source: Literal["INPUT", "OUTPUT"] + content: List[BedrockContentItem] From 7d55047ab9f99926d6147cc2b6c448c25e4c684d Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 16:09:55 -0700 Subject: [PATCH 16/22] add bedrock guardrails support --- .../guardrail_hooks/bedrock_guardrails.py | 273 ++++++++++++++++++ litellm/proxy/guardrails/init_guardrails.py | 18 +- litellm/proxy/proxy_config.yaml | 12 +- litellm/types/guardrails.py | 6 +- 4 files changed, 296 insertions(+), 13 deletions(-) create mode 100644 litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py new file mode 100644 index 000000000..6c7ea4d90 --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -0,0 +1,273 @@ +# +-------------------------------------------------------------+ +# +# Use Bedrock Guardrails for your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! 
- Krrish & Ishaan + +import os +import sys + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import asyncio +import json +import sys +import traceback +import uuid +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional, Union + +import aiohttp +import httpx +from fastapi import HTTPException + +import litellm +from litellm import get_secret +from litellm._logging import verbose_proxy_logger +from litellm.caching import DualCache +from litellm.integrations.custom_guardrail import CustomGuardrail +from litellm.litellm_core_utils.logging_utils import ( + convert_litellm_response_object_to_str, +) +from litellm.llms.base_aws_llm import BaseAWSLLM +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + _get_async_httpx_client, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata +from litellm.types.guardrails import ( + BedrockContentItem, + BedrockRequest, + BedrockTextContent, + GuardrailEventHooks, +) + +GUARDRAIL_NAME = "bedrock" + + +class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): + def __init__( + self, + guardrailIdentifier: Optional[str] = None, + guardrailVersion: Optional[str] = None, + **kwargs, + ): + self.async_handler = _get_async_httpx_client() + self.guardrailIdentifier = guardrailIdentifier + self.guardrailVersion = guardrailVersion + + # store kwargs as optional_params + self.optional_params = kwargs + + super().__init__(**kwargs) + + def convert_to_bedrock_format( + self, + messages: Optional[List[Dict[str, str]]] = None, + ) -> BedrockRequest: + bedrock_request: BedrockRequest = BedrockRequest(source="INPUT") + if messages: + bedrock_request_content: List[BedrockContentItem] = [] + for message in messages: + content = message.get("content") + if isinstance(content, str): + bedrock_content_item = BedrockContentItem( + text=BedrockTextContent(text=content) + ) + bedrock_request_content.append(bedrock_content_item) + + bedrock_request["content"] = bedrock_request_content + + return bedrock_request + + #### CALL HOOKS - proxy only #### + def _load_credentials( + self, + ): + try: + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = self.optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = self.optional_params.pop("aws_access_key_id", None) + aws_session_token = self.optional_params.pop("aws_session_token", None) + aws_region_name = self.optional_params.pop("aws_region_name", None) + aws_role_name = self.optional_params.pop("aws_role_name", None) + aws_session_name = self.optional_params.pop("aws_session_name", None) + aws_profile_name = self.optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = self.optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = self.optional_params.pop( + "aws_web_identity_token", None + ) + aws_sts_endpoint = self.optional_params.pop("aws_sts_endpoint", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, + ) + return credentials, aws_region_name + + def _prepare_request( + self, + credentials, + data: BedrockRequest, + optional_params: dict, + aws_region_name: str, + extra_headers: Optional[dict] = None, + ): + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply" + + encoded_data = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + + request = AWSRequest( + method="POST", url=api_base, data=encoded_data, headers=headers + ) + sigv4.add_auth(request) + prepped_request = request.prepare() + + return prepped_request + + async def make_bedrock_api_request(self, kwargs: dict): + + credentials, aws_region_name = self._load_credentials() + request_data: BedrockRequest = self.convert_to_bedrock_format( + messages=kwargs.get("messages") + ) + prepared_request = self._prepare_request( + credentials=credentials, + data=request_data, + optional_params=self.optional_params, + aws_region_name=aws_region_name, + ) + verbose_proxy_logger.debug( + "Bedrock AI request body: %s, url %s, headers: %s", + request_data, + prepared_request.url, + prepared_request.headers, + ) + _json_data = json.dumps(request_data) # type: ignore + response = await self.async_handler.post( + url=prepared_request.url, + json=request_data, # type: ignore + headers=prepared_request.headers, + ) + verbose_proxy_logger.debug("Bedrock AI response: %s", response.text) + if response.status_code == 200: + # check if the response was flagged + _json_response = response.json() + if _json_response.get("action") == "GUARDRAIL_INTERVENED": + raise HTTPException( + status_code=400, + detail={ + "error": "Violated guardrail policy", + "bedrock_guardrail_response": _json_response, + }, + ) + else: + verbose_proxy_logger.error( + "Bedrock AI: error in response. Status code: %s, response: %s", + response.status_code, + response.text, + ) + + async def async_moderation_hook( ### 👈 KEY CHANGE ### + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal["completion", "embeddings", "image_generation"], + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks + + event_type: GuardrailEventHooks = GuardrailEventHooks.during_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + new_messages: Optional[List[dict]] = data.get("messages") + if new_messages is not None: + await self.make_bedrock_api_request(kwargs=data) + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + else: + verbose_proxy_logger.warning( + "Bedrock AI: not running guardrail. 
No messages in data" + ) + pass + + # async def async_post_call_success_hook( + # self, + # data: dict, + # user_api_key_dict: UserAPIKeyAuth, + # response, + # ): + # from litellm.proxy.common_utils.callback_utils import ( + # add_guardrail_to_applied_guardrails_header, + # ) + # from litellm.types.guardrails import GuardrailEventHooks + + # """ + # Use this for the post call moderation with Guardrails + # """ + # event_type: GuardrailEventHooks = GuardrailEventHooks.post_call + # if self.should_run_guardrail(data=data, event_type=event_type) is not True: + # return + + # response_str: Optional[str] = convert_litellm_response_object_to_str(response) + # if response_str is not None: + # await self.make_bedrock_api_request( + # response_string=response_str, new_messages=data.get("messages", []) + # ) + + # add_guardrail_to_applied_guardrails_header( + # request_data=data, guardrail_name=self.guardrail_name + # ) + + # pass diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index ad99daf95..f0e2a9e2e 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -96,8 +96,10 @@ def init_guardrails_v2(all_guardrails: dict): litellm_params = LitellmParams( guardrail=litellm_params_data["guardrail"], mode=litellm_params_data["mode"], - api_key=litellm_params_data["api_key"], - api_base=litellm_params_data["api_base"], + api_key=litellm_params_data.get("api_key"), + api_base=litellm_params_data.get("api_base"), + guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"), + guardrailVersion=litellm_params_data.get("guardrailVersion"), ) if ( @@ -134,6 +136,18 @@ def init_guardrails_v2(all_guardrails: dict): event_hook=litellm_params["mode"], ) litellm.callbacks.append(_aporia_callback) # type: ignore + if litellm_params["guardrail"] == "bedrock": + from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import ( + BedrockGuardrail, + ) + + _bedrock_callback = BedrockGuardrail( + guardrail_name=guardrail["guardrail_name"], + event_hook=litellm_params["mode"], + guardrailIdentifier=litellm_params["guardrailIdentifier"], + guardrailVersion=litellm_params["guardrailVersion"], + ) + litellm.callbacks.append(_bedrock_callback) # type: ignore elif litellm_params["guardrail"] == "lakera": from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import ( lakeraAI_Moderation, diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 65c7f7052..d8e88cec7 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -6,13 +6,9 @@ model_list: api_base: https://exampleopenaiendpoint-production.up.railway.app/ guardrails: - - guardrail_name: "lakera-pre-guard" + - guardrail_name: "bedrock-pre-guard" litellm_params: - guardrail: lakera # supported values: "aporia", "bedrock", "lakera" + guardrail: bedrock # supported values: "aporia", "bedrock", "lakera" mode: "during_call" - api_key: os.environ/LAKERA_API_KEY - api_base: os.environ/LAKERA_API_BASE - category_thresholds: - prompt_injection: 0.1 - jailbreak: 0.1 - \ No newline at end of file + guardrailIdentifier: ff6ujrregl1q + guardrailVersion: "DRAFT" \ No newline at end of file diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 13992beec..10f4be7e1 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -100,14 +100,14 @@ class GuardrailEventHooks(str, Enum): during_call = "during_call" -class BedrockTextContent(TypedDict): +class 
BedrockTextContent(TypedDict, total=False): text: str -class BedrockContentItem(TypedDict): +class BedrockContentItem(TypedDict, total=False): text: BedrockTextContent -class BedrockMessage(TypedDict): +class BedrockRequest(TypedDict, total=False): source: Literal["INPUT", "OUTPUT"] content: List[BedrockContentItem] From 499b6b33688dd43a9b1cf55c5ffa8370b5568dc1 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 16:25:22 -0700 Subject: [PATCH 17/22] doc bedrock guardrails --- .../docs/proxy/guardrails/bedrock.md | 135 ++++++++++++++++++ docs/my-website/sidebars.js | 2 +- 2 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 docs/my-website/docs/proxy/guardrails/bedrock.md diff --git a/docs/my-website/docs/proxy/guardrails/bedrock.md b/docs/my-website/docs/proxy/guardrails/bedrock.md new file mode 100644 index 000000000..ac8aa1c1b --- /dev/null +++ b/docs/my-website/docs/proxy/guardrails/bedrock.md @@ -0,0 +1,135 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# Bedrock + +## Quick Start +### 1. Define Guardrails on your LiteLLM config.yaml + +Define your guardrails under the `guardrails` section +```yaml +model_list: + - model_name: gpt-3.5-turbo + litellm_params: + model: openai/gpt-3.5-turbo + api_key: os.environ/OPENAI_API_KEY + +guardrails: + - guardrail_name: "bedrock-pre-guard" + litellm_params: + guardrail: bedrock # supported values: "aporia", "bedrock", "lakera" + mode: "during_call" + guardrailIdentifier: ff6ujrregl1q # your guardrail ID on bedrock + guardrailVersion: "DRAFT" # your guardrail version on bedrock + +``` + +#### Supported values for `mode` + +- `pre_call` Run **before** LLM call, on **input** +- `post_call` Run **after** LLM call, on **input & output** +- `during_call` Run **during** LLM call, on **input** Same as `pre_call` but runs in parallel as LLM call. Response not returned until guardrail check completes + +### 2. Start LiteLLM Gateway + + +```shell +litellm --config config.yaml --detailed_debug +``` + +### 3. Test request + +**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys##request-format)** + + + + +Expect this to fail since since `ishaan@berri.ai` in the request is PII + +```shell +curl -i http://localhost:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \ + -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "hi my email is ishaan@berri.ai"} + ], + "guardrails": ["bedrock-guard"] + }' +``` + +Expected response on failure + +```shell +{ + "error": { + "message": { + "error": "Violated guardrail policy", + "bedrock_guardrail_response": { + "action": "GUARDRAIL_INTERVENED", + "assessments": [ + { + "topicPolicy": { + "topics": [ + { + "action": "BLOCKED", + "name": "Coffee", + "type": "DENY" + } + ] + } + } + ], + "blockedResponse": "Sorry, the model cannot answer this question. coffee guardrail applied ", + "output": [ + { + "text": "Sorry, the model cannot answer this question. coffee guardrail applied " + } + ], + "outputs": [ + { + "text": "Sorry, the model cannot answer this question. 
coffee guardrail applied " + } + ], + "usage": { + "contentPolicyUnits": 0, + "contextualGroundingPolicyUnits": 0, + "sensitiveInformationPolicyFreeUnits": 0, + "sensitiveInformationPolicyUnits": 0, + "topicPolicyUnits": 1, + "wordPolicyUnits": 0 + } + } + }, + "type": "None", + "param": "None", + "code": "400" + } +} + +``` + + + + + +```shell +curl -i http://localhost:4000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \ + -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "user", "content": "hi what is the weather"} + ], + "guardrails": ["bedrock-guard"] + }' +``` + + + + + + diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index ab94ed5b4..b907a1130 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -54,7 +54,7 @@ const sidebars = { { type: "category", label: "🛡️ [Beta] Guardrails", - items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai"], + items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai", "proxy/guardrails/bedrock"], }, { type: "category", From 9e3d573bcb6c14dc517ae9f930b952cd6b472698 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 16:34:43 -0700 Subject: [PATCH 18/22] add async_post_call_success_hook --- .../guardrail_hooks/bedrock_guardrails.py | 78 +++++++++++-------- litellm/proxy/proxy_config.yaml | 7 +- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index 6c7ea4d90..d11f58a3e 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -67,10 +67,12 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): def convert_to_bedrock_format( self, messages: Optional[List[Dict[str, str]]] = None, + response: Optional[Union[Any, litellm.ModelResponse]] = None, ) -> BedrockRequest: bedrock_request: BedrockRequest = BedrockRequest(source="INPUT") + bedrock_request_content: List[BedrockContentItem] = [] + if messages: - bedrock_request_content: List[BedrockContentItem] = [] for message in messages: content = message.get("content") if isinstance(content, str): @@ -80,7 +82,19 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): bedrock_request_content.append(bedrock_content_item) bedrock_request["content"] = bedrock_request_content - + if response: + bedrock_request["source"] = "OUTPUT" + if isinstance(response, litellm.ModelResponse): + for choice in response.choices: + if isinstance(choice, litellm.Choices): + if choice.message.content and isinstance( + choice.message.content, str + ): + bedrock_content_item = BedrockContentItem( + text=BedrockTextContent(text=choice.message.content) + ) + bedrock_request_content.append(bedrock_content_item) + bedrock_request["content"] = bedrock_request_content return bedrock_request #### CALL HOOKS - proxy only #### @@ -172,11 +186,13 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): return prepped_request - async def make_bedrock_api_request(self, kwargs: dict): + async def make_bedrock_api_request( + self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None + ): credentials, aws_region_name = self._load_credentials() request_data: BedrockRequest = self.convert_to_bedrock_format( - messages=kwargs.get("messages") + messages=kwargs.get("messages"), 
response=response ) prepared_request = self._prepare_request( credentials=credentials, @@ -242,32 +258,32 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): ) pass - # async def async_post_call_success_hook( - # self, - # data: dict, - # user_api_key_dict: UserAPIKeyAuth, - # response, - # ): - # from litellm.proxy.common_utils.callback_utils import ( - # add_guardrail_to_applied_guardrails_header, - # ) - # from litellm.types.guardrails import GuardrailEventHooks + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks - # """ - # Use this for the post call moderation with Guardrails - # """ - # event_type: GuardrailEventHooks = GuardrailEventHooks.post_call - # if self.should_run_guardrail(data=data, event_type=event_type) is not True: - # return + if ( + self.should_run_guardrail( + data=data, event_type=GuardrailEventHooks.post_call + ) + is not True + ): + return - # response_str: Optional[str] = convert_litellm_response_object_to_str(response) - # if response_str is not None: - # await self.make_bedrock_api_request( - # response_string=response_str, new_messages=data.get("messages", []) - # ) - - # add_guardrail_to_applied_guardrails_header( - # request_data=data, guardrail_name=self.guardrail_name - # ) - - # pass + new_messages: Optional[List[dict]] = data.get("messages") + if new_messages is not None: + await self.make_bedrock_api_request(kwargs=data, response=response) + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + else: + verbose_proxy_logger.warning( + "Bedrock AI: not running guardrail. No messages in data" + ) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index d8e88cec7..d0ed9a699 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,14 +1,13 @@ model_list: - model_name: gpt-4 litellm_params: - model: openai/fake - api_key: fake-key - api_base: https://exampleopenaiendpoint-production.up.railway.app/ + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY guardrails: - guardrail_name: "bedrock-pre-guard" litellm_params: guardrail: bedrock # supported values: "aporia", "bedrock", "lakera" - mode: "during_call" + mode: "post_call" guardrailIdentifier: ff6ujrregl1q guardrailVersion: "DRAFT" \ No newline at end of file From d7b525f391b7c7bcbad43e1fb90afb92478226de Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Aug 2024 16:37:46 -0700 Subject: [PATCH 19/22] feat(auth_checks.py): allow team to call all models, when explicitly set via /* --- litellm/proxy/_new_secret_config.yaml | 9 --------- litellm/proxy/auth/auth_checks.py | 8 ++++++-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 66530c7db..96a0242a8 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -2,12 +2,3 @@ model_list: - model_name: "*" litellm_params: model: "*" - -litellm_settings: - success_callback: ["s3"] - cache: true - s3_callback_params: - s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3 - s3_region_name: us-west-2 # AWS Region Name for S3 - s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/ to pass environment variables. 
This is AWS Access Key ID for S3 - s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3 \ No newline at end of file diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index cf5065c2e..0f1452651 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -66,7 +66,7 @@ def common_checks( raise Exception( f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin." ) - # 2. If user can call model + # 2. If team can call model if ( _model is not None and team_object is not None @@ -74,7 +74,11 @@ def common_checks( and _model not in team_object.models ): # this means the team has access to all models on the proxy - if "all-proxy-models" in team_object.models: + if ( + "all-proxy-models" in team_object.models + or "*" in team_object.models + or "openai/*" in team_object.models + ): # this means the team has access to all models on the proxy pass # check if the team model is an access_group From 735fc804edd44d6b34f900a535578a5199398907 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Aug 2024 16:49:52 -0700 Subject: [PATCH 20/22] fix(proxy_server.py): expose flag to disable retries when max parallel request limit is hit --- docs/my-website/docs/proxy/configs.md | 1 + litellm/proxy/proxy_server.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index 19c1f7902..d08d6324d 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -727,6 +727,7 @@ general_settings: "completion_model": "string", "disable_spend_logs": "boolean", # turn off writing each transaction to the db "disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint) + "disable_retry_on_max_parallel_request_limit_error": "boolean", # turn off retries when max parallel request limit is reached "disable_reset_budget": "boolean", # turn off reset budget scheduled task "disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking "enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 6bd528def..c793ffbe3 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2672,6 +2672,13 @@ def giveup(e): and isinstance(e.message, str) and "Max parallel request limit reached" in e.message ) + + if ( + general_settings.get("disable_retry_on_max_parallel_request_limit_error") + is True + ): + return True # giveup if queuing max parallel request limits is disabled + if result: verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)})) return result From e445b78490b59ba30a62391c6c9b45d3879a2b4b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 22 Aug 2024 17:12:52 -0700 Subject: [PATCH 21/22] docs(configs.md): add global_max_parallel_requests to docs --- docs/my-website/docs/proxy/configs.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md index d08d6324d..a50b3f646 100644 --- a/docs/my-website/docs/proxy/configs.md +++ b/docs/my-website/docs/proxy/configs.md @@ -752,7 +752,8 @@ general_settings: }, "otel": true, "custom_auth": "string", - "max_parallel_requests": 0, + "max_parallel_requests": 0, # the max 
parallel requests allowed per deployment + "global_max_parallel_requests": 0, # the max parallel requests allowed on the proxy all up "infer_model_from_keys": true, "background_health_checks": true, "health_check_interval": 300, From 1f0cc725316a8410f0e20e10297442ce84ec2022 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 22 Aug 2024 17:24:42 -0700 Subject: [PATCH 22/22] test bedrock guardrails --- .circleci/config.yml | 3 +++ .../example_config_yaml/otel_test_config.yaml | 8 ++++++- litellm/proxy/proxy_config.yaml | 2 +- tests/otel_tests/test_guardrails.py | 23 +++++++++++++++++++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 24d826f4f..f8393be9d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -320,6 +320,9 @@ jobs: -e APORIA_API_BASE_2=$APORIA_API_BASE_2 \ -e APORIA_API_KEY_2=$APORIA_API_KEY_2 \ -e APORIA_API_BASE_1=$APORIA_API_BASE_1 \ + -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \ + -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \ + -e AWS_REGION_NAME=$AWS_REGION_NAME \ -e APORIA_API_KEY_1=$APORIA_API_KEY_1 \ --name my-app \ -v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \ diff --git a/litellm/proxy/example_config_yaml/otel_test_config.yaml b/litellm/proxy/example_config_yaml/otel_test_config.yaml index 496ae1710..8ca4f37fd 100644 --- a/litellm/proxy/example_config_yaml/otel_test_config.yaml +++ b/litellm/proxy/example_config_yaml/otel_test_config.yaml @@ -21,4 +21,10 @@ guardrails: guardrail: aporia # supported values: "aporia", "bedrock", "lakera" mode: "post_call" api_key: os.environ/APORIA_API_KEY_2 - api_base: os.environ/APORIA_API_BASE_2 \ No newline at end of file + api_base: os.environ/APORIA_API_BASE_2 + - guardrail_name: "bedrock-pre-guard" + litellm_params: + guardrail: bedrock # supported values: "aporia", "bedrock", "lakera" + mode: "pre_call" + guardrailIdentifier: ff6ujrregl1q + guardrailVersion: "DRAFT" \ No newline at end of file diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index d0ed9a699..6b831876f 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,5 +1,5 @@ model_list: - - model_name: gpt-4 + - model_name: fake-openai-endpoint litellm_params: model: openai/gpt-4 api_key: os.environ/OPENAI_API_KEY diff --git a/tests/otel_tests/test_guardrails.py b/tests/otel_tests/test_guardrails.py index 7e9ff613a..34f14186e 100644 --- a/tests/otel_tests/test_guardrails.py +++ b/tests/otel_tests/test_guardrails.py @@ -144,6 +144,7 @@ async def test_no_llm_guard_triggered(): assert "x-litellm-applied-guardrails" not in headers + @pytest.mark.asyncio async def test_guardrails_with_api_key_controls(): """ @@ -194,3 +195,25 @@ async def test_guardrails_with_api_key_controls(): except Exception as e: print(e) assert "Aporia detected and blocked PII" in str(e) + + +@pytest.mark.asyncio +async def test_bedrock_guardrail_triggered(): + """ + - Tests a request where our bedrock guardrail should be triggered + - Assert that the guardrails applied are returned in the response headers + """ + async with aiohttp.ClientSession() as session: + try: + response, headers = await chat_completion( + session, + "sk-1234", + model="fake-openai-endpoint", + messages=[{"role": "user", "content": f"Hello do you like coffee?"}], + guardrails=["bedrock-pre-guard"], + ) + pytest.fail("Should have thrown an exception") + except Exception as e: + print(e) + assert "GUARDRAIL_INTERVENED" in str(e) + assert "Violated 
guardrail policy" in str(e)
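
Note on the new test above: it drives the proxy through a `chat_completion` helper that is defined earlier in `tests/otel_tests/test_guardrails.py` and is not part of this diff. A minimal sketch of what such a helper plausibly does is shown below — the base URL, payload shape, and error handling are assumptions for illustration only, not the repo's exact implementation:

```python
import json

import aiohttp

PROXY_BASE_URL = "http://localhost:4000"  # assumed address of the locally running proxy


async def chat_completion(
    session: aiohttp.ClientSession,
    key: str,
    model: str,
    messages: list,
    guardrails: list = None,
):
    """Hypothetical helper: call the proxy's /chat/completions endpoint and
    raise if the request is rejected (e.g. by a Bedrock guardrail)."""
    payload = {"model": model, "messages": messages}
    if guardrails is not None:
        # extra body param picked up by the guardrail hooks on the proxy
        payload["guardrails"] = guardrails

    async with session.post(
        f"{PROXY_BASE_URL}/chat/completions",
        headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
        json=payload,
    ) as resp:
        raw = await resp.text()
        if resp.status != 200:
            # A blocked request surfaces the guardrail error body, which contains
            # "GUARDRAIL_INTERVENED" and "Violated guardrail policy" — the strings
            # the new test asserts on.
            raise Exception(raw)
        return json.loads(raw), resp.headers
```

With the `bedrock-pre-guard` entry added to `otel_test_config.yaml` above, a prompt like "Hello do you like coffee?" is expected to be blocked by the configured Bedrock guardrail, so the test asserts that the raised error contains `GUARDRAIL_INTERVENED` and `Violated guardrail policy`, mirroring the error body shown in the docs example at the top of this section.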