diff --git a/.circleci/config.yml b/.circleci/config.yml
index 3f457b30b..a5abab254 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -321,6 +321,9 @@ jobs:
-e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
-e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
-e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
+ -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
+ -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
+ -e AWS_REGION_NAME=$AWS_REGION_NAME \
-e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
diff --git a/docs/my-website/docs/projects/dbally.md b/docs/my-website/docs/projects/dbally.md
new file mode 100644
index 000000000..688f1ab0f
--- /dev/null
+++ b/docs/my-website/docs/projects/dbally.md
@@ -0,0 +1,3 @@
+Efficient, consistent and secure library for querying structured data with natural language. Query any database with over 100 LLMs ❤️ 🚅.
+
+🔗 [GitHub](https://github.com/deepsense-ai/db-ally)
diff --git a/docs/my-website/docs/providers/azure_ai.md b/docs/my-website/docs/providers/azure_ai.md
index 26c965a0c..23993b52a 100644
--- a/docs/my-website/docs/providers/azure_ai.md
+++ b/docs/my-website/docs/providers/azure_ai.md
@@ -307,8 +307,9 @@ LiteLLM supports **ALL** azure ai models. Here's a few examples:
| Model Name | Function Call |
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
-| Cohere command-r | `completion(model="azure/command-r", messages)` |
-| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |
+| Cohere command-r-plus | `completion(model="azure_ai/command-r-plus", messages)` |
+| Cohere command-r | `completion(model="azure_ai/command-r", messages)` |
+| mistral-large-latest | `completion(model="azure_ai/mistral-large-latest", messages)` |
+| AI21-Jamba-Instruct | `completion(model="azure_ai/ai21-jamba-instruct", messages)` |
diff --git a/docs/my-website/docs/proxy/configs.md b/docs/my-website/docs/proxy/configs.md
index 19c1f7902..a50b3f646 100644
--- a/docs/my-website/docs/proxy/configs.md
+++ b/docs/my-website/docs/proxy/configs.md
@@ -727,6 +727,7 @@ general_settings:
"completion_model": "string",
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
+ "disable_retry_on_max_parallel_request_limit_error": "boolean", # turn off retries when max parallel request limit is reached
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
"disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
@@ -751,7 +752,8 @@ general_settings:
},
"otel": true,
"custom_auth": "string",
- "max_parallel_requests": 0,
+ "max_parallel_requests": 0, # the max parallel requests allowed per deployment
+ "global_max_parallel_requests": 0, # the max parallel requests allowed on the proxy all up
"infer_model_from_keys": true,
"background_health_checks": true,
"health_check_interval": 300,
diff --git a/docs/my-website/docs/proxy/guardrails/bedrock.md b/docs/my-website/docs/proxy/guardrails/bedrock.md
new file mode 100644
index 000000000..ac8aa1c1b
--- /dev/null
+++ b/docs/my-website/docs/proxy/guardrails/bedrock.md
@@ -0,0 +1,135 @@
+import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# Bedrock
+
+## Quick Start
+### 1. Define Guardrails on your LiteLLM config.yaml
+
+Define your guardrails under the `guardrails` section
+```yaml
+model_list:
+ - model_name: gpt-3.5-turbo
+ litellm_params:
+ model: openai/gpt-3.5-turbo
+ api_key: os.environ/OPENAI_API_KEY
+
+guardrails:
+ - guardrail_name: "bedrock-pre-guard"
+ litellm_params:
+ guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
+ mode: "during_call"
+ guardrailIdentifier: ff6ujrregl1q # your guardrail ID on bedrock
+ guardrailVersion: "DRAFT" # your guardrail version on bedrock
+
+```
+
+#### Supported values for `mode`
+
+- `pre_call` Run **before** LLM call, on **input**
+- `post_call` Run **after** LLM call, on **input & output**
+- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with the LLM call. The response is not returned until the guardrail check completes
+
+### 2. Start LiteLLM Gateway
+
+
+```shell
+litellm --config config.yaml --detailed_debug
+```
+
+### 3. Test request
+
+**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
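+
+For reference, here is a rough OpenAI Python SDK equivalent of the curl requests below. This is a sketch only: it assumes the proxy is reachable at `http://localhost:4000` and that extra params such as `guardrails` are forwarded to the proxy via `extra_body`:
+
+```python
+import openai
+
+client = openai.OpenAI(
+    api_key="sk-npnwjPQciVRok5yNZgKmFQ",  # your LiteLLM proxy virtual key
+    base_url="http://localhost:4000",     # your LiteLLM proxy URL
+)
+
+response = client.chat.completions.create(
+    model="gpt-3.5-turbo",
+    messages=[{"role": "user", "content": "hi my email is ishaan@berri.ai"}],
+    extra_body={"guardrails": ["bedrock-pre-guard"]},  # guardrail name from config.yaml
+)
+print(response)
+```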
+
+Expect this request to fail, since `ishaan@berri.ai` in the request is PII
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
+ -d '{
+ "model": "gpt-3.5-turbo",
+ "messages": [
+ {"role": "user", "content": "hi my email is ishaan@berri.ai"}
+ ],
+ "guardrails": ["bedrock-guard"]
+ }'
+```
+
+Expected response on failure
+
+```json
+{
+ "error": {
+ "message": {
+ "error": "Violated guardrail policy",
+ "bedrock_guardrail_response": {
+ "action": "GUARDRAIL_INTERVENED",
+ "assessments": [
+ {
+ "topicPolicy": {
+ "topics": [
+ {
+ "action": "BLOCKED",
+ "name": "Coffee",
+ "type": "DENY"
+ }
+ ]
+ }
+ }
+ ],
+ "blockedResponse": "Sorry, the model cannot answer this question. coffee guardrail applied ",
+ "output": [
+ {
+ "text": "Sorry, the model cannot answer this question. coffee guardrail applied "
+ }
+ ],
+ "outputs": [
+ {
+ "text": "Sorry, the model cannot answer this question. coffee guardrail applied "
+ }
+ ],
+ "usage": {
+ "contentPolicyUnits": 0,
+ "contextualGroundingPolicyUnits": 0,
+ "sensitiveInformationPolicyFreeUnits": 0,
+ "sensitiveInformationPolicyUnits": 0,
+ "topicPolicyUnits": 1,
+ "wordPolicyUnits": 0
+ }
+ }
+ },
+ "type": "None",
+ "param": "None",
+ "code": "400"
+ }
+}
+
+```
+
+Expect this request to succeed, since it does not violate the guardrail policy
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
+ -d '{
+ "model": "gpt-3.5-turbo",
+ "messages": [
+ {"role": "user", "content": "hi what is the weather"}
+ ],
+ "guardrails": ["bedrock-guard"]
+ }'
+```
+
diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md
index 4b913d2e8..10e6456c2 100644
--- a/docs/my-website/docs/proxy/prometheus.md
+++ b/docs/my-website/docs/proxy/prometheus.md
@@ -68,6 +68,15 @@ http://localhost:4000/metrics
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
+### Request Latency Metrics
+
+| Metric Name | Description |
+|----------------------|--------------------------------------|
+| `litellm_request_total_latency_metric` | Total latency (seconds) for a request to LiteLLM Proxy Server - tracked for labels `litellm_call_id`, `model` |
+| `litellm_llm_api_latency_metric` | Latency (seconds) for just the LLM API call - tracked for labels `litellm_call_id`, `model` |
+
+
+
### LLM API / Provider Metrics
| Metric Name | Description |
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index ab94ed5b4..4bb4125e0 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -54,7 +54,7 @@ const sidebars = {
{
type: "category",
label: "🛡️ [Beta] Guardrails",
- items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai"],
+ items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai", "proxy/guardrails/bedrock"],
},
{
type: "category",
@@ -292,6 +292,7 @@ const sidebars = {
items: [
"projects/Docq.AI",
"projects/OpenInterpreter",
+ "projects/dbally",
"projects/FastREPL",
"projects/PROMPTMETHEUS",
"projects/Codium PR Agent",
diff --git a/litellm-js/spend-logs/package-lock.json b/litellm-js/spend-logs/package-lock.json
index cb4b599d3..5d8b85ad5 100644
--- a/litellm-js/spend-logs/package-lock.json
+++ b/litellm-js/spend-logs/package-lock.json
@@ -6,7 +6,7 @@
"": {
"dependencies": {
"@hono/node-server": "^1.10.1",
- "hono": "^4.2.7"
+ "hono": "^4.5.8"
},
"devDependencies": {
"@types/node": "^20.11.17",
@@ -463,9 +463,9 @@
}
},
"node_modules/hono": {
- "version": "4.2.7",
- "resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
- "integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
+ "version": "4.5.8",
+ "resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz",
+ "integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==",
"engines": {
"node": ">=16.0.0"
}
diff --git a/litellm-js/spend-logs/package.json b/litellm-js/spend-logs/package.json
index d9543220b..359935c25 100644
--- a/litellm-js/spend-logs/package.json
+++ b/litellm-js/spend-logs/package.json
@@ -4,7 +4,7 @@
},
"dependencies": {
"@hono/node-server": "^1.10.1",
- "hono": "^4.2.7"
+ "hono": "^4.5.8"
},
"devDependencies": {
"@types/node": "^20.11.17",
diff --git a/litellm/_redis.py b/litellm/_redis.py
index d72016dcd..23f82ed1a 100644
--- a/litellm/_redis.py
+++ b/litellm/_redis.py
@@ -7,13 +7,17 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
+import inspect
+
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
import os
-import inspect
-import redis, litellm # type: ignore
-import redis.asyncio as async_redis # type: ignore
from typing import List, Optional
+import redis # type: ignore
+import redis.asyncio as async_redis # type: ignore
+
+import litellm
+
def _get_redis_kwargs():
arg_spec = inspect.getfullargspec(redis.Redis)
@@ -51,6 +55,19 @@ def _get_redis_url_kwargs(client=None):
return available_args
+def _get_redis_cluster_kwargs(client=None):
+ if client is None:
+ client = redis.Redis.from_url
+ arg_spec = inspect.getfullargspec(redis.RedisCluster)
+
+ # Only allow primitive arguments
+ exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
+
+ available_args = [x for x in arg_spec.args if x not in exclude_args]
+
+ return available_args
+
+
def _get_redis_env_kwarg_mapping():
PREFIX = "REDIS_"
@@ -124,6 +141,22 @@ def get_redis_client(**env_overrides):
url_kwargs[arg] = redis_kwargs[arg]
return redis.Redis.from_url(**url_kwargs)
+
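+ # Redis Cluster support: when startup_nodes is provided, convert the node dicts
+ # into ClusterNode objects and only pass kwargs that redis.RedisCluster accepts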
+ if "startup_nodes" in redis_kwargs:
+ from redis.cluster import ClusterNode
+
+ args = _get_redis_cluster_kwargs()
+ cluster_kwargs = {}
+ for arg in redis_kwargs:
+ if arg in args:
+ cluster_kwargs[arg] = redis_kwargs[arg]
+
+ new_startup_nodes: List[ClusterNode] = []
+
+ for item in redis_kwargs["startup_nodes"]:
+ new_startup_nodes.append(ClusterNode(**item))
+ redis_kwargs.pop("startup_nodes")
+ return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.Redis(**redis_kwargs)
@@ -143,6 +176,24 @@ def get_redis_async_client(**env_overrides):
)
return async_redis.Redis.from_url(**url_kwargs)
+ if "startup_nodes" in redis_kwargs:
+ from redis.cluster import ClusterNode
+
+ args = _get_redis_cluster_kwargs()
+ cluster_kwargs = {}
+ for arg in redis_kwargs:
+ if arg in args:
+ cluster_kwargs[arg] = redis_kwargs[arg]
+
+ new_startup_nodes: List[ClusterNode] = []
+
+ for item in redis_kwargs["startup_nodes"]:
+ new_startup_nodes.append(ClusterNode(**item))
+ redis_kwargs.pop("startup_nodes")
+ return async_redis.RedisCluster(
+ startup_nodes=new_startup_nodes, **cluster_kwargs
+ )
+
return async_redis.Redis(
socket_timeout=5,
**redis_kwargs,
@@ -160,4 +211,5 @@ def get_redis_connection_pool(**env_overrides):
connection_class = async_redis.SSLConnection
redis_kwargs.pop("ssl", None)
redis_kwargs["connection_class"] = connection_class
+ redis_kwargs.pop("startup_nodes", None)
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
diff --git a/litellm/caching.py b/litellm/caching.py
index 1c7216029..1b19fdf3e 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -203,6 +203,7 @@ class RedisCache(BaseCache):
password=None,
redis_flush_size=100,
namespace: Optional[str] = None,
+ startup_nodes: Optional[List] = None, # for redis-cluster
**kwargs,
):
import redis
@@ -218,7 +219,8 @@ class RedisCache(BaseCache):
redis_kwargs["port"] = port
if password is not None:
redis_kwargs["password"] = password
-
+ if startup_nodes is not None:
+ redis_kwargs["startup_nodes"] = startup_nodes
### HEALTH MONITORING OBJECT ###
if kwargs.get("service_logger_obj", None) is not None and isinstance(
kwargs["service_logger_obj"], ServiceLogging
@@ -246,7 +248,7 @@ class RedisCache(BaseCache):
### ASYNC HEALTH PING ###
try:
# asyncio.get_running_loop().create_task(self.ping())
- result = asyncio.get_running_loop().create_task(self.ping())
+ _ = asyncio.get_running_loop().create_task(self.ping())
except Exception as e:
if "no running event loop" in str(e):
verbose_logger.debug(
@@ -2123,6 +2125,7 @@ class Cache:
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
+ redis_startup_nodes: Optional[List] = None,
disk_cache_dir=None,
qdrant_api_base: Optional[str] = None,
qdrant_api_key: Optional[str] = None,
@@ -2155,7 +2158,12 @@ class Cache:
"""
if type == "redis":
self.cache: BaseCache = RedisCache(
- host, port, password, redis_flush_size, **kwargs
+ host,
+ port,
+ password,
+ redis_flush_size,
+ startup_nodes=redis_startup_nodes,
+ **kwargs,
)
elif type == "redis-semantic":
self.cache = RedisSemanticCache(
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py
index 321c1cc1f..659e5b193 100644
--- a/litellm/integrations/prometheus.py
+++ b/litellm/integrations/prometheus.py
@@ -60,6 +60,25 @@ class PrometheusLogger(CustomLogger):
],
)
+ # request latency metrics
+ self.litellm_request_total_latency_metric = Histogram(
+ "litellm_request_total_latency_metric",
+ "Total latency (seconds) for a request to LiteLLM",
+ labelnames=[
+ "model",
+ "litellm_call_id",
+ ],
+ )
+
+ self.litellm_llm_api_latency_metric = Histogram(
+ "litellm_llm_api_latency_metric",
+ "Total latency (seconds) for a models LLM API call",
+ labelnames=[
+ "model",
+ "litellm_call_id",
+ ],
+ )
+
# Counter for spend
self.litellm_spend_metric = Counter(
"litellm_spend_metric",
@@ -103,26 +122,27 @@ class PrometheusLogger(CustomLogger):
"Remaining budget for api key",
labelnames=["hashed_api_key", "api_key_alias"],
)
+
+ ########################################
+ # LiteLLM Virtual API KEY metrics
+ ########################################
+ # Remaining MODEL RPM limit for API Key
+ self.litellm_remaining_api_key_requests_for_model = Gauge(
+ "litellm_remaining_api_key_requests_for_model",
+ "Remaining Requests API Key can make for model (model based rpm limit on key)",
+ labelnames=["hashed_api_key", "api_key_alias", "model"],
+ )
+
+ # Remaining MODEL TPM limit for API Key
+ self.litellm_remaining_api_key_tokens_for_model = Gauge(
+ "litellm_remaining_api_key_tokens_for_model",
+ "Remaining Tokens API Key can make for model (model based tpm limit on key)",
+ labelnames=["hashed_api_key", "api_key_alias", "model"],
+ )
+
# Litellm-Enterprise Metrics
if premium_user is True:
- ########################################
- # LiteLLM Virtual API KEY metrics
- ########################################
- # Remaining MODEL RPM limit for API Key
- self.litellm_remaining_api_key_requests_for_model = Gauge(
- "litellm_remaining_api_key_requests_for_model",
- "Remaining Requests API Key can make for model (model based rpm limit on key)",
- labelnames=["hashed_api_key", "api_key_alias", "model"],
- )
-
- # Remaining MODEL TPM limit for API Key
- self.litellm_remaining_api_key_tokens_for_model = Gauge(
- "litellm_remaining_api_key_tokens_for_model",
- "Remaining Tokens API Key can make for model (model based tpm limit on key)",
- labelnames=["hashed_api_key", "api_key_alias", "model"],
- )
-
########################################
# LLM API Deployment Metrics / analytics
########################################
@@ -328,6 +348,25 @@ class PrometheusLogger(CustomLogger):
user_api_key, user_api_key_alias, model_group
).set(remaining_tokens)
+ # latency metrics
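+ # total latency = end_time - start_time; API latency uses "api_call_start_time",
+ # which litellm_logging stamps just before the provider request is sent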
+ total_time: timedelta = kwargs.get("end_time") - kwargs.get("start_time")
+ total_time_seconds = total_time.total_seconds()
+ api_call_total_time: timedelta = kwargs.get("end_time") - kwargs.get(
+ "api_call_start_time"
+ )
+
+ api_call_total_time_seconds = api_call_total_time.total_seconds()
+
+ litellm_call_id = kwargs.get("litellm_call_id")
+
+ self.litellm_request_total_latency_metric.labels(
+ model, litellm_call_id
+ ).observe(total_time_seconds)
+
+ self.litellm_llm_api_latency_metric.labels(model, litellm_call_id).observe(
+ api_call_total_time_seconds
+ )
+
# set x-ratelimit headers
if premium_user is True:
self.set_llm_deployment_success_metrics(
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index d59f98558..dbf2a7d3e 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -354,6 +354,8 @@ class Logging:
str(e)
)
)
+
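+ # record when the provider API call is about to start; used by integrations (e.g. Prometheus) to compute LLM API latency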
+ self.model_call_details["api_call_start_time"] = datetime.datetime.now()
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index e45559752..23e7fdc3e 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-1-8b-instruct-v1:0",
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-405b-instruct-v1:0",
+ "meta.llama3-70b-instruct-v1:0",
"mistral.mistral-large-2407-v1:0",
]
@@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
optional_params: dict,
acompletion: bool,
timeout: Optional[Union[float, httpx.Timeout]],
- litellm_params=None,
+ litellm_params: dict,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
## TRANSFORMATION ##
+
+ bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
+ messages=messages,
+ model=model,
+ llm_provider="bedrock_converse",
+ user_continue_message=litellm_params.pop("user_continue_message", None),
+ )
+
# send all model-specific params in 'additional_request_params'
for k, v in inference_params.items():
if (
@@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
for key in additional_request_keys:
inference_params.pop(key, None)
- bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
- messages=messages,
- model=model,
- llm_provider="bedrock_converse",
- )
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
inference_params.pop("tools", [])
)
diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
index 3873027b2..8bd1051e8 100644
--- a/litellm/llms/cohere.py
+++ b/litellm/llms/cohere.py
@@ -124,12 +124,14 @@ class CohereConfig:
}
-def validate_environment(api_key):
- headers = {
- "Request-Source": "unspecified:litellm",
- "accept": "application/json",
- "content-type": "application/json",
- }
+def validate_environment(api_key, headers: dict):
+ headers.update(
+ {
+ "Request-Source": "unspecified:litellm",
+ "accept": "application/json",
+ "content-type": "application/json",
+ }
+ )
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@@ -144,11 +146,12 @@ def completion(
encoding,
api_key,
logging_obj,
+ headers: dict,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
- headers = validate_environment(api_key)
+ headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
prompt = " ".join(message["content"] for message in messages)
@@ -338,13 +341,14 @@ def embedding(
model_response: litellm.EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
+ headers: dict,
encoding: Any,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
- headers = validate_environment(api_key)
+ headers = validate_environment(api_key, headers=headers)
embed_url = "https://api.cohere.ai/v1/embed"
model = model
data = {"model": model, "texts": input, **optional_params}
diff --git a/litellm/llms/cohere_chat.py b/litellm/llms/cohere_chat.py
index a0a9a9874..f13e74614 100644
--- a/litellm/llms/cohere_chat.py
+++ b/litellm/llms/cohere_chat.py
@@ -116,12 +116,14 @@ class CohereChatConfig:
}
-def validate_environment(api_key):
- headers = {
- "Request-Source": "unspecified:litellm",
- "accept": "application/json",
- "content-type": "application/json",
- }
+def validate_environment(api_key, headers: dict):
+ headers.update(
+ {
+ "Request-Source": "unspecified:litellm",
+ "accept": "application/json",
+ "content-type": "application/json",
+ }
+ )
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
@@ -203,13 +205,14 @@ def completion(
model_response: ModelResponse,
print_verbose: Callable,
optional_params: dict,
+ headers: dict,
encoding,
api_key,
logging_obj,
litellm_params=None,
logger_fn=None,
):
- headers = validate_environment(api_key)
+ headers = validate_environment(api_key, headers=headers)
completion_url = api_base
model = model
most_recent_message, chat_history = cohere_messages_pt_v2(
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index c9e691c00..2b9a7fc24 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
BAD_MESSAGE_ERROR_STR = "Invalid Message "
+# used to interweave user messages, to ensure user/assistant messages alternate
+DEFAULT_USER_CONTINUE_MESSAGE = {
+ "role": "user",
+ "content": "Please continue.",
+} # similar to autogen. Only used if `litellm.modify_params=True`.
+
+# used to interweave assistant messages, to ensure user/assistant messages alternate
+DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
+ "role": "assistant",
+ "content": "Please continue.",
+} # similar to autogen. Only used if `litellm.modify_params=True`.
+
def map_system_message_pt(messages: list) -> list:
"""
@@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
messages: List,
model: str,
llm_provider: str,
+ user_continue_message: Optional[dict] = None,
) -> List[BedrockMessageBlock]:
"""
Converts given messages from OpenAI format to Bedrock format
@@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
contents: List[BedrockMessageBlock] = []
msg_i = 0
+
+ # if initial message is assistant message
+ if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
+ if user_continue_message is not None:
+ messages.insert(0, user_continue_message)
+ elif litellm.modify_params:
+ messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
+
+ # if final message is assistant message
+ if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
+ if user_continue_message is not None:
+ messages.append(user_continue_message)
+ elif litellm.modify_params:
+ messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
+
while msg_i < len(messages):
user_content: List[BedrockContentBlock] = []
init_msg_i = msg_i
@@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
model=model,
llm_provider=llm_provider,
)
+
return contents
diff --git a/litellm/main.py b/litellm/main.py
index 45e164a89..16a4f89ed 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -944,6 +944,8 @@ def completion(
cooldown_time=cooldown_time,
text_completion=kwargs.get("text_completion"),
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
+ user_continue_message=kwargs.get("user_continue_message"),
)
logging.update_environment_variables(
model=model,
@@ -1635,6 +1637,13 @@ def completion(
or "https://api.cohere.ai/v1/generate"
)
+ headers = headers or litellm.headers or {}
+ if headers is None:
+ headers = {}
+
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
model_response = cohere.completion(
model=model,
messages=messages,
@@ -1645,6 +1654,7 @@ def completion(
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
+ headers=headers,
api_key=cohere_key,
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
@@ -1675,6 +1685,13 @@ def completion(
or "https://api.cohere.ai/v1/chat"
)
+ headers = headers or litellm.headers or {}
+ if headers is None:
+ headers = {}
+
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
model_response = cohere_chat.completion(
model=model,
messages=messages,
@@ -1683,6 +1700,7 @@ def completion(
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
+ headers=headers,
logger_fn=logger_fn,
encoding=encoding,
api_key=cohere_key,
@@ -2289,7 +2307,7 @@ def completion(
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
- litellm_params=litellm_params,
+ litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
@@ -3159,6 +3177,7 @@ def embedding(
encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.get("aembedding", None)
+ extra_headers = kwargs.get("extra_headers", None)
### CUSTOM MODEL COST ###
input_cost_per_token = kwargs.get("input_cost_per_token", None)
output_cost_per_token = kwargs.get("output_cost_per_token", None)
@@ -3234,6 +3253,7 @@ def embedding(
"tenant_id",
"client_id",
"client_secret",
+ "extra_headers",
]
default_params = openai_params + litellm_params
non_default_params = {
@@ -3297,7 +3317,7 @@ def embedding(
"cooldown_time": cooldown_time,
},
)
- if azure == True or custom_llm_provider == "azure":
+ if azure is True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -3403,12 +3423,18 @@ def embedding(
or get_secret("CO_API_KEY")
or litellm.api_key
)
+
+ if extra_headers is not None and isinstance(extra_headers, dict):
+ headers = extra_headers
+ else:
+ headers = {}
response = cohere.embedding(
model=model,
input=input,
optional_params=optional_params,
encoding=encoding,
api_key=cohere_key, # type: ignore
+ headers=headers,
logging_obj=logging,
model_response=EmbeddingResponse(),
aembedding=aembedding,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 2c888a4f3..96a0242a8 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -2,12 +2,3 @@ model_list:
- model_name: "*"
litellm_params:
model: "*"
-
-litellm_settings:
- success_callback: ["s3"]
- cache: true
- s3_callback_params:
- s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
- s3_region_name: us-west-2 # AWS Region Name for S3
- s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/ to pass environment variables. This is AWS Access Key ID for S3
- s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index cf5065c2e..0f1452651 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -66,7 +66,7 @@ def common_checks(
raise Exception(
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
)
- # 2. If user can call model
+ # 2. If team can call model
if (
_model is not None
and team_object is not None
@@ -74,7 +74,11 @@ def common_checks(
and _model not in team_object.models
):
# this means the team has access to all models on the proxy
- if "all-proxy-models" in team_object.models:
+ if (
+ "all-proxy-models" in team_object.models
+ or "*" in team_object.models
+ or "openai/*" in team_object.models
+ ):
# this means the team has access to all models on the proxy
pass
# check if the team model is an access_group
diff --git a/litellm/proxy/example_config_yaml/otel_test_config.yaml b/litellm/proxy/example_config_yaml/otel_test_config.yaml
index 496ae1710..8ca4f37fd 100644
--- a/litellm/proxy/example_config_yaml/otel_test_config.yaml
+++ b/litellm/proxy/example_config_yaml/otel_test_config.yaml
@@ -21,4 +21,10 @@ guardrails:
guardrail: aporia # supported values: "aporia", "bedrock", "lakera"
mode: "post_call"
api_key: os.environ/APORIA_API_KEY_2
- api_base: os.environ/APORIA_API_BASE_2
\ No newline at end of file
+ api_base: os.environ/APORIA_API_BASE_2
+ - guardrail_name: "bedrock-pre-guard"
+ litellm_params:
+ guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
+ mode: "pre_call"
+ guardrailIdentifier: ff6ujrregl1q
+ guardrailVersion: "DRAFT"
\ No newline at end of file
diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
new file mode 100644
index 000000000..d11f58a3e
--- /dev/null
+++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
@@ -0,0 +1,289 @@
+# +-------------------------------------------------------------+
+#
+# Use Bedrock Guardrails for your LLM calls
+#
+# +-------------------------------------------------------------+
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import os
+import sys
+
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path
+import asyncio
+import json
+import sys
+import traceback
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Literal, Optional, Union
+
+import aiohttp
+import httpx
+from fastapi import HTTPException
+
+import litellm
+from litellm import get_secret
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.litellm_core_utils.logging_utils import (
+ convert_litellm_response_object_to_str,
+)
+from litellm.llms.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ _get_async_httpx_client,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
+from litellm.types.guardrails import (
+ BedrockContentItem,
+ BedrockRequest,
+ BedrockTextContent,
+ GuardrailEventHooks,
+)
+
+GUARDRAIL_NAME = "bedrock"
+
+
+class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
+ def __init__(
+ self,
+ guardrailIdentifier: Optional[str] = None,
+ guardrailVersion: Optional[str] = None,
+ **kwargs,
+ ):
+ self.async_handler = _get_async_httpx_client()
+ self.guardrailIdentifier = guardrailIdentifier
+ self.guardrailVersion = guardrailVersion
+
+ # store kwargs as optional_params
+ self.optional_params = kwargs
+
+ super().__init__(**kwargs)
+
+ def convert_to_bedrock_format(
+ self,
+ messages: Optional[List[Dict[str, str]]] = None,
+ response: Optional[Union[Any, litellm.ModelResponse]] = None,
+ ) -> BedrockRequest:
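+ # default to scanning the request INPUT; switched to OUTPUT below when a model response is passed in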
+ bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
+ bedrock_request_content: List[BedrockContentItem] = []
+
+ if messages:
+ for message in messages:
+ content = message.get("content")
+ if isinstance(content, str):
+ bedrock_content_item = BedrockContentItem(
+ text=BedrockTextContent(text=content)
+ )
+ bedrock_request_content.append(bedrock_content_item)
+
+ bedrock_request["content"] = bedrock_request_content
+ if response:
+ bedrock_request["source"] = "OUTPUT"
+ if isinstance(response, litellm.ModelResponse):
+ for choice in response.choices:
+ if isinstance(choice, litellm.Choices):
+ if choice.message.content and isinstance(
+ choice.message.content, str
+ ):
+ bedrock_content_item = BedrockContentItem(
+ text=BedrockTextContent(text=choice.message.content)
+ )
+ bedrock_request_content.append(bedrock_content_item)
+ bedrock_request["content"] = bedrock_request_content
+ return bedrock_request
+
+ #### CALL HOOKS - proxy only ####
+ def _load_credentials(
+ self,
+ ):
+ try:
+ from botocore.credentials import Credentials
+ except ImportError as e:
+ raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+ ## CREDENTIALS ##
+ # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
+ aws_secret_access_key = self.optional_params.pop("aws_secret_access_key", None)
+ aws_access_key_id = self.optional_params.pop("aws_access_key_id", None)
+ aws_session_token = self.optional_params.pop("aws_session_token", None)
+ aws_region_name = self.optional_params.pop("aws_region_name", None)
+ aws_role_name = self.optional_params.pop("aws_role_name", None)
+ aws_session_name = self.optional_params.pop("aws_session_name", None)
+ aws_profile_name = self.optional_params.pop("aws_profile_name", None)
+ aws_bedrock_runtime_endpoint = self.optional_params.pop(
+ "aws_bedrock_runtime_endpoint", None
+ ) # https://bedrock-runtime.{region_name}.amazonaws.com
+ aws_web_identity_token = self.optional_params.pop(
+ "aws_web_identity_token", None
+ )
+ aws_sts_endpoint = self.optional_params.pop("aws_sts_endpoint", None)
+
+ ### SET REGION NAME ###
+ if aws_region_name is None:
+ # check env #
+ litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
+
+ if litellm_aws_region_name is not None and isinstance(
+ litellm_aws_region_name, str
+ ):
+ aws_region_name = litellm_aws_region_name
+
+ standard_aws_region_name = get_secret("AWS_REGION", None)
+ if standard_aws_region_name is not None and isinstance(
+ standard_aws_region_name, str
+ ):
+ aws_region_name = standard_aws_region_name
+
+ if aws_region_name is None:
+ aws_region_name = "us-west-2"
+
+ credentials: Credentials = self.get_credentials(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ aws_session_token=aws_session_token,
+ aws_region_name=aws_region_name,
+ aws_session_name=aws_session_name,
+ aws_profile_name=aws_profile_name,
+ aws_role_name=aws_role_name,
+ aws_web_identity_token=aws_web_identity_token,
+ aws_sts_endpoint=aws_sts_endpoint,
+ )
+ return credentials, aws_region_name
+
+ def _prepare_request(
+ self,
+ credentials,
+ data: BedrockRequest,
+ optional_params: dict,
+ aws_region_name: str,
+ extra_headers: Optional[dict] = None,
+ ):
+ try:
+ import boto3
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from botocore.credentials import Credentials
+ except ImportError as e:
+ raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+
+ sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
+ api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
+
+ encoded_data = json.dumps(data).encode("utf-8")
+ headers = {"Content-Type": "application/json"}
+ if extra_headers is not None:
+ headers = {"Content-Type": "application/json", **extra_headers}
+
+ request = AWSRequest(
+ method="POST", url=api_base, data=encoded_data, headers=headers
+ )
+ sigv4.add_auth(request)
+ prepped_request = request.prepare()
+
+ return prepped_request
+
+ async def make_bedrock_api_request(
+ self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
+ ):
+
+ credentials, aws_region_name = self._load_credentials()
+ request_data: BedrockRequest = self.convert_to_bedrock_format(
+ messages=kwargs.get("messages"), response=response
+ )
+ prepared_request = self._prepare_request(
+ credentials=credentials,
+ data=request_data,
+ optional_params=self.optional_params,
+ aws_region_name=aws_region_name,
+ )
+ verbose_proxy_logger.debug(
+ "Bedrock AI request body: %s, url %s, headers: %s",
+ request_data,
+ prepared_request.url,
+ prepared_request.headers,
+ )
+ _json_data = json.dumps(request_data) # type: ignore
+ response = await self.async_handler.post(
+ url=prepared_request.url,
+ json=request_data, # type: ignore
+ headers=prepared_request.headers,
+ )
+ verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
+ if response.status_code == 200:
+ # check if the response was flagged
+ _json_response = response.json()
+ if _json_response.get("action") == "GUARDRAIL_INTERVENED":
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Violated guardrail policy",
+ "bedrock_guardrail_response": _json_response,
+ },
+ )
+ else:
+ verbose_proxy_logger.error(
+ "Bedrock AI: error in response. Status code: %s, response: %s",
+ response.status_code,
+ response.text,
+ )
+
+ async def async_moderation_hook( ### 👈 KEY CHANGE ###
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ call_type: Literal["completion", "embeddings", "image_generation"],
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ add_guardrail_to_applied_guardrails_header,
+ )
+ from litellm.types.guardrails import GuardrailEventHooks
+
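+ # runs while the LLM call is in flight ("during_call" mode), so the guardrail check happens in parallel with the model call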
+ event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
+ if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+ return
+
+ new_messages: Optional[List[dict]] = data.get("messages")
+ if new_messages is not None:
+ await self.make_bedrock_api_request(kwargs=data)
+ add_guardrail_to_applied_guardrails_header(
+ request_data=data, guardrail_name=self.guardrail_name
+ )
+ else:
+ verbose_proxy_logger.warning(
+ "Bedrock AI: not running guardrail. No messages in data"
+ )
+ pass
+
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response,
+ ):
+ from litellm.proxy.common_utils.callback_utils import (
+ add_guardrail_to_applied_guardrails_header,
+ )
+ from litellm.types.guardrails import GuardrailEventHooks
+
+ if (
+ self.should_run_guardrail(
+ data=data, event_type=GuardrailEventHooks.post_call
+ )
+ is not True
+ ):
+ return
+
+ new_messages: Optional[List[dict]] = data.get("messages")
+ if new_messages is not None:
+ await self.make_bedrock_api_request(kwargs=data, response=response)
+ add_guardrail_to_applied_guardrails_header(
+ request_data=data, guardrail_name=self.guardrail_name
+ )
+ else:
+ verbose_proxy_logger.warning(
+ "Bedrock AI: not running guardrail. No messages in data"
+ )
diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py
index ad99daf95..f0e2a9e2e 100644
--- a/litellm/proxy/guardrails/init_guardrails.py
+++ b/litellm/proxy/guardrails/init_guardrails.py
@@ -96,8 +96,10 @@ def init_guardrails_v2(all_guardrails: dict):
litellm_params = LitellmParams(
guardrail=litellm_params_data["guardrail"],
mode=litellm_params_data["mode"],
- api_key=litellm_params_data["api_key"],
- api_base=litellm_params_data["api_base"],
+ api_key=litellm_params_data.get("api_key"),
+ api_base=litellm_params_data.get("api_base"),
+ guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
+ guardrailVersion=litellm_params_data.get("guardrailVersion"),
)
if (
@@ -134,6 +136,18 @@ def init_guardrails_v2(all_guardrails: dict):
event_hook=litellm_params["mode"],
)
litellm.callbacks.append(_aporia_callback) # type: ignore
+ if litellm_params["guardrail"] == "bedrock":
+ from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
+ BedrockGuardrail,
+ )
+
+ _bedrock_callback = BedrockGuardrail(
+ guardrail_name=guardrail["guardrail_name"],
+ event_hook=litellm_params["mode"],
+ guardrailIdentifier=litellm_params["guardrailIdentifier"],
+ guardrailVersion=litellm_params["guardrailVersion"],
+ )
+ litellm.callbacks.append(_bedrock_callback) # type: ignore
elif litellm_params["guardrail"] == "lakera":
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
lakeraAI_Moderation,
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index c8599d56e..320216a79 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,5 +1,5 @@
model_list:
- - model_name: gpt-4
+ - model_name: fake-openai-endpoint
litellm_params:
model: azure/chatgpt-v-2
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
@@ -9,13 +9,9 @@ model_list:
client_secret: os.environ/AZURE_CLIENT_SECRET
guardrails:
- - guardrail_name: "lakera-pre-guard"
+ - guardrail_name: "bedrock-pre-guard"
litellm_params:
- guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
- mode: "during_call"
- api_key: os.environ/LAKERA_API_KEY
- api_base: os.environ/LAKERA_API_BASE
- category_thresholds:
- prompt_injection: 0.1
- jailbreak: 0.1
-
\ No newline at end of file
+ guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
+ mode: "post_call"
+ guardrailIdentifier: ff6ujrregl1q
+ guardrailVersion: "DRAFT"
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 0a9abc09a..c793ffbe3 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1588,7 +1588,7 @@ class ProxyConfig:
verbose_proxy_logger.debug( # noqa
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
)
- elif key == "cache" and value == False:
+ elif key == "cache" and value is False:
pass
elif key == "guardrails":
if premium_user is not True:
@@ -2672,6 +2672,13 @@ def giveup(e):
and isinstance(e.message, str)
and "Max parallel request limit reached" in e.message
)
+
+ if (
+ general_settings.get("disable_retry_on_max_parallel_request_limit_error")
+ is True
+ ):
+ return True # giveup if queuing max parallel request limits is disabled
+
if result:
verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
return result
diff --git a/litellm/router.py b/litellm/router.py
index e261c1743..7a938f5c4 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -277,7 +277,8 @@ class Router:
"local" # default to an in-memory cache
)
redis_cache = None
- cache_config = {}
+ cache_config: Dict[str, Any] = {}
+
self.client_ttl = client_ttl
if redis_url is not None or (
redis_host is not None
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index 4892601b1..90592b499 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
"temperature": 0.3,
"messages": [
{"role": "system", "content": system},
- {"role": "user", "content": "hey, how's it going?"},
+ {"role": "assistant", "content": "hey, how's it going?"},
],
+ "user_continue_message": {"role": "user", "content": "Be a good bot!"},
}
response: ModelResponse = completion(
model="bedrock/{}".format(model),
diff --git a/litellm/tests/test_caching.py b/litellm/tests/test_caching.py
index 64196e5c5..e474dff2e 100644
--- a/litellm/tests/test_caching.py
+++ b/litellm/tests/test_caching.py
@@ -804,6 +804,38 @@ def test_redis_cache_completion_stream():
# test_redis_cache_completion_stream()
+@pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
+@pytest.mark.asyncio
+async def test_redis_cache_cluster_init_unit_test():
+ try:
+ from redis.asyncio import RedisCluster as AsyncRedisCluster
+ from redis.cluster import RedisCluster
+
+ from litellm.caching import RedisCache
+
+ litellm.set_verbose = True
+
+ # List of startup nodes
+ startup_nodes = [
+ {"host": "127.0.0.1", "port": "7001"},
+ ]
+
+ resp = RedisCache(startup_nodes=startup_nodes)
+
+ assert isinstance(resp.redis_client, RedisCluster)
+ assert isinstance(resp.init_async_client(), AsyncRedisCluster)
+
+ resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes)
+
+ assert isinstance(resp.cache, RedisCache)
+ assert isinstance(resp.cache.redis_client, RedisCluster)
+ assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster)
+
+ except Exception as e:
+ print(f"{str(e)}\n\n{traceback.format_exc()}")
+ raise e
+
+
@pytest.mark.asyncio
async def test_redis_cache_acompletion_stream():
try:
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 0941484d9..c0c3c70f9 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -3653,6 +3653,7 @@ def test_completion_cohere():
response = completion(
model="command-r",
messages=messages,
+ extra_headers={"Helicone-Property-Locale": "ko"},
)
print(response)
except Exception as e:
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index 66c2a535a..10f4be7e1 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -1,5 +1,5 @@
from enum import Enum
-from typing import Dict, List, Optional, TypedDict
+from typing import Dict, List, Literal, Optional, TypedDict
from pydantic import BaseModel, ConfigDict
from typing_extensions import Required, TypedDict
@@ -76,8 +76,14 @@ class LitellmParams(TypedDict, total=False):
mode: str
api_key: str
api_base: Optional[str]
+
+ # Lakera specific params
category_thresholds: Optional[LakeraCategoryThresholds]
+ # Bedrock specific params
+ guardrailIdentifier: Optional[str]
+ guardrailVersion: Optional[str]
+
class Guardrail(TypedDict):
guardrail_name: str
@@ -92,3 +98,16 @@ class GuardrailEventHooks(str, Enum):
pre_call = "pre_call"
post_call = "post_call"
during_call = "during_call"
+
+
+class BedrockTextContent(TypedDict, total=False):
+ text: str
+
+
+class BedrockContentItem(TypedDict, total=False):
+ text: BedrockTextContent
+
+
+class BedrockRequest(TypedDict, total=False):
+ source: Literal["INPUT", "OUTPUT"]
+ content: List[BedrockContentItem]
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 21eae868a..14d5cd1b8 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1120,6 +1120,7 @@ all_litellm_params = [
"tenant_id",
"client_id",
"client_secret",
+ "user_continue_message",
]
diff --git a/litellm/utils.py b/litellm/utils.py
index ea6b99524..d5aefa80e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2324,6 +2324,7 @@ def get_litellm_params(
cooldown_time=None,
text_completion=None,
azure_ad_token_provider=None,
+ user_continue_message=None,
):
litellm_params = {
"acompletion": acompletion,
@@ -2349,6 +2350,7 @@ def get_litellm_params(
"cooldown_time": cooldown_time,
"text_completion": text_completion,
"azure_ad_token_provider": azure_ad_token_provider,
+ "user_continue_message": user_continue_message,
}
return litellm_params
@@ -4221,6 +4223,7 @@ def get_supported_openai_params(
"presence_penalty",
"stop",
"n",
+ "extra_headers",
]
elif custom_llm_provider == "cohere_chat":
return [
@@ -4235,6 +4238,7 @@ def get_supported_openai_params(
"tools",
"tool_choice",
"seed",
+ "extra_headers",
]
elif custom_llm_provider == "maritalk":
return [
@@ -7123,6 +7127,14 @@ def exception_type(
llm_provider="bedrock",
response=original_exception.response,
)
+ elif "A conversation must start with a user message." in error_str:
+ exception_mapping_worked = True
+ raise BadRequestError(
+ message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
+ model=model,
+ llm_provider="bedrock",
+ response=original_exception.response,
+ )
elif (
"Unable to locate credentials" in error_str
or "The security token included in the request is invalid"
diff --git a/tests/otel_tests/test_guardrails.py b/tests/otel_tests/test_guardrails.py
index 7e9ff613a..34f14186e 100644
--- a/tests/otel_tests/test_guardrails.py
+++ b/tests/otel_tests/test_guardrails.py
@@ -144,6 +144,7 @@ async def test_no_llm_guard_triggered():
assert "x-litellm-applied-guardrails" not in headers
+
@pytest.mark.asyncio
async def test_guardrails_with_api_key_controls():
"""
@@ -194,3 +195,25 @@ async def test_guardrails_with_api_key_controls():
except Exception as e:
print(e)
assert "Aporia detected and blocked PII" in str(e)
+
+
+@pytest.mark.asyncio
+async def test_bedrock_guardrail_triggered():
+ """
+ - Tests a request where our bedrock guardrail should be triggered
+ - Assert that the guardrails applied are returned in the response headers
+ """
+ async with aiohttp.ClientSession() as session:
+ try:
+ response, headers = await chat_completion(
+ session,
+ "sk-1234",
+ model="fake-openai-endpoint",
+ messages=[{"role": "user", "content": f"Hello do you like coffee?"}],
+ guardrails=["bedrock-pre-guard"],
+ )
+ pytest.fail("Should have thrown an exception")
+ except Exception as e:
+ print(e)
+ assert "GUARDRAIL_INTERVENED" in str(e)
+ assert "Violated guardrail policy" in str(e)