forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_allow_using_azure_ad_token_auth
This commit is contained in:
commit
228252b92d
33 changed files with 802 additions and 84 deletions
|
@ -321,6 +321,9 @@ jobs:
|
||||||
-e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
|
-e APORIA_API_BASE_2=$APORIA_API_BASE_2 \
|
||||||
-e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
|
-e APORIA_API_KEY_2=$APORIA_API_KEY_2 \
|
||||||
-e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
|
-e APORIA_API_BASE_1=$APORIA_API_BASE_1 \
|
||||||
|
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
|
||||||
|
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
|
||||||
|
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||||
-e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
|
-e APORIA_API_KEY_1=$APORIA_API_KEY_1 \
|
||||||
--name my-app \
|
--name my-app \
|
||||||
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
|
-v $(pwd)/litellm/proxy/example_config_yaml/otel_test_config.yaml:/app/config.yaml \
|
||||||
|
|
3
docs/my-website/docs/projects/dbally.md
Normal file
3
docs/my-website/docs/projects/dbally.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
Efficient, consistent and secure library for querying structured data with natural language. Query any database with over 100 LLMs ❤️ 🚅.
|
||||||
|
|
||||||
|
🔗 [GitHub](https://github.com/deepsense-ai/db-ally)
|
|
@ -307,8 +307,9 @@ LiteLLM supports **ALL** azure ai models. Here's a few examples:
|
||||||
|
|
||||||
| Model Name | Function Call |
|
| Model Name | Function Call |
|
||||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||||
| Cohere command-r-plus | `completion(model="azure/command-r-plus", messages)` |
|
| Cohere command-r-plus | `completion(model="azure_ai/command-r-plus", messages)` |
|
||||||
| Cohere command-r | `completion(model="azure/command-r", messages)` |
|
| Cohere command-r | `completion(model="azure_ai/command-r", messages)` |
|
||||||
| mistral-large-latest | `completion(model="azure/mistral-large-latest", messages)` |
|
| mistral-large-latest | `completion(model="azure_ai/mistral-large-latest", messages)` |
|
||||||
|
| AI21-Jamba-Instruct | `completion(model="azure_ai/ai21-jamba-instruct", messages)` |
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -727,6 +727,7 @@ general_settings:
|
||||||
"completion_model": "string",
|
"completion_model": "string",
|
||||||
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
|
"disable_spend_logs": "boolean", # turn off writing each transaction to the db
|
||||||
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
|
"disable_master_key_return": "boolean", # turn off returning master key on UI (checked on '/user/info' endpoint)
|
||||||
|
"disable_retry_on_max_parallel_request_limit_error": "boolean", # turn off retries when max parallel request limit is reached
|
||||||
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
|
"disable_reset_budget": "boolean", # turn off reset budget scheduled task
|
||||||
"disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking
|
"disable_adding_master_key_hash_to_db": "boolean", # turn off storing master key hash in db, for spend tracking
|
||||||
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
|
"enable_jwt_auth": "boolean", # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
|
||||||
|
@ -751,7 +752,8 @@ general_settings:
|
||||||
},
|
},
|
||||||
"otel": true,
|
"otel": true,
|
||||||
"custom_auth": "string",
|
"custom_auth": "string",
|
||||||
"max_parallel_requests": 0,
|
"max_parallel_requests": 0, # the max parallel requests allowed per deployment
|
||||||
|
"global_max_parallel_requests": 0, # the max parallel requests allowed on the proxy all up
|
||||||
"infer_model_from_keys": true,
|
"infer_model_from_keys": true,
|
||||||
"background_health_checks": true,
|
"background_health_checks": true,
|
||||||
"health_check_interval": 300,
|
"health_check_interval": 300,
|
||||||
|
|
135
docs/my-website/docs/proxy/guardrails/bedrock.md
Normal file
135
docs/my-website/docs/proxy/guardrails/bedrock.md
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
import Image from '@theme/IdealImage';
|
||||||
|
import Tabs from '@theme/Tabs';
|
||||||
|
import TabItem from '@theme/TabItem';
|
||||||
|
|
||||||
|
# Bedrock
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
### 1. Define Guardrails on your LiteLLM config.yaml
|
||||||
|
|
||||||
|
Define your guardrails under the `guardrails` section
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: openai/gpt-3.5-turbo
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
|
guardrails:
|
||||||
|
- guardrail_name: "bedrock-pre-guard"
|
||||||
|
litellm_params:
|
||||||
|
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
|
||||||
|
mode: "during_call"
|
||||||
|
guardrailIdentifier: ff6ujrregl1q # your guardrail ID on bedrock
|
||||||
|
guardrailVersion: "DRAFT" # your guardrail version on bedrock
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Supported values for `mode`
|
||||||
|
|
||||||
|
- `pre_call` Run **before** LLM call, on **input**
|
||||||
|
- `post_call` Run **after** LLM call, on **input & output**
|
||||||
|
- `during_call` Run **during** LLM call, on **input** Same as `pre_call` but runs in parallel as LLM call. Response not returned until guardrail check completes
|
||||||
|
|
||||||
|
### 2. Start LiteLLM Gateway
|
||||||
|
|
||||||
|
|
||||||
|
```shell
|
||||||
|
litellm --config config.yaml --detailed_debug
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Test request
|
||||||
|
|
||||||
|
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys##request-format)**
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem label="Unsuccessful call" value = "not-allowed">
|
||||||
|
|
||||||
|
Expect this to fail since since `ishaan@berri.ai` in the request is PII
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -i http://localhost:4000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "hi my email is ishaan@berri.ai"}
|
||||||
|
],
|
||||||
|
"guardrails": ["bedrock-guard"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Expected response on failure
|
||||||
|
|
||||||
|
```shell
|
||||||
|
{
|
||||||
|
"error": {
|
||||||
|
"message": {
|
||||||
|
"error": "Violated guardrail policy",
|
||||||
|
"bedrock_guardrail_response": {
|
||||||
|
"action": "GUARDRAIL_INTERVENED",
|
||||||
|
"assessments": [
|
||||||
|
{
|
||||||
|
"topicPolicy": {
|
||||||
|
"topics": [
|
||||||
|
{
|
||||||
|
"action": "BLOCKED",
|
||||||
|
"name": "Coffee",
|
||||||
|
"type": "DENY"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"blockedResponse": "Sorry, the model cannot answer this question. coffee guardrail applied ",
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"text": "Sorry, the model cannot answer this question. coffee guardrail applied "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"text": "Sorry, the model cannot answer this question. coffee guardrail applied "
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"contentPolicyUnits": 0,
|
||||||
|
"contextualGroundingPolicyUnits": 0,
|
||||||
|
"sensitiveInformationPolicyFreeUnits": 0,
|
||||||
|
"sensitiveInformationPolicyUnits": 0,
|
||||||
|
"topicPolicyUnits": 1,
|
||||||
|
"wordPolicyUnits": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "None",
|
||||||
|
"param": "None",
|
||||||
|
"code": "400"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem label="Successful Call " value = "allowed">
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl -i http://localhost:4000/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer sk-npnwjPQciVRok5yNZgKmFQ" \
|
||||||
|
-d '{
|
||||||
|
"model": "gpt-3.5-turbo",
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "hi what is the weather"}
|
||||||
|
],
|
||||||
|
"guardrails": ["bedrock-guard"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
|
@ -68,6 +68,15 @@ http://localhost:4000/metrics
|
||||||
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
|
| `litellm_total_tokens` | input + output tokens per `"user", "key", "model", "team", "end-user"` |
|
||||||
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
|
| `litellm_llm_api_failed_requests_metric` | Number of failed LLM API requests per `"user", "key", "model", "team", "end-user"` |
|
||||||
|
|
||||||
|
### Request Latency Metrics
|
||||||
|
|
||||||
|
| Metric Name | Description |
|
||||||
|
|----------------------|--------------------------------------|
|
||||||
|
| `litellm_request_total_latency_metric` | Total latency (seconds) for a request to LiteLLM Proxy Server - tracked for labels `litellm_call_id`, `model` |
|
||||||
|
| `litellm_llm_api_latency_metric` | latency (seconds) for just the LLM API call - tracked for labels `litellm_call_id`, `model` |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### LLM API / Provider Metrics
|
### LLM API / Provider Metrics
|
||||||
|
|
||||||
| Metric Name | Description |
|
| Metric Name | Description |
|
||||||
|
|
|
@ -54,7 +54,7 @@ const sidebars = {
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
label: "🛡️ [Beta] Guardrails",
|
label: "🛡️ [Beta] Guardrails",
|
||||||
items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai"],
|
items: ["proxy/guardrails/quick_start", "proxy/guardrails/aporia_api", "proxy/guardrails/lakera_ai", "proxy/guardrails/bedrock"],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
type: "category",
|
type: "category",
|
||||||
|
@ -292,6 +292,7 @@ const sidebars = {
|
||||||
items: [
|
items: [
|
||||||
"projects/Docq.AI",
|
"projects/Docq.AI",
|
||||||
"projects/OpenInterpreter",
|
"projects/OpenInterpreter",
|
||||||
|
"projects/dbally",
|
||||||
"projects/FastREPL",
|
"projects/FastREPL",
|
||||||
"projects/PROMPTMETHEUS",
|
"projects/PROMPTMETHEUS",
|
||||||
"projects/Codium PR Agent",
|
"projects/Codium PR Agent",
|
||||||
|
|
8
litellm-js/spend-logs/package-lock.json
generated
8
litellm-js/spend-logs/package-lock.json
generated
|
@ -6,7 +6,7 @@
|
||||||
"": {
|
"": {
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@hono/node-server": "^1.10.1",
|
"@hono/node-server": "^1.10.1",
|
||||||
"hono": "^4.2.7"
|
"hono": "^4.5.8"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/node": "^20.11.17",
|
"@types/node": "^20.11.17",
|
||||||
|
@ -463,9 +463,9 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"node_modules/hono": {
|
"node_modules/hono": {
|
||||||
"version": "4.2.7",
|
"version": "4.5.8",
|
||||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.2.7.tgz",
|
"resolved": "https://registry.npmjs.org/hono/-/hono-4.5.8.tgz",
|
||||||
"integrity": "sha512-k1xHi86tJnRIVvqhFMBDGFKJ8r5O+bEsT4P59ZK59r0F300Xd910/r237inVfuT/VmE86RQQffX4OYNda6dLXw==",
|
"integrity": "sha512-pqpSlcdqGkpTTRpLYU1PnCz52gVr0zVR9H5GzMyJWuKQLLEBQxh96q45QizJ2PPX8NATtz2mu31/PKW/Jt+90Q==",
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=16.0.0"
|
"node": ">=16.0.0"
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@hono/node-server": "^1.10.1",
|
"@hono/node-server": "^1.10.1",
|
||||||
"hono": "^4.2.7"
|
"hono": "^4.5.8"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/node": "^20.11.17",
|
"@types/node": "^20.11.17",
|
||||||
|
|
|
@ -7,13 +7,17 @@
|
||||||
#
|
#
|
||||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||||
import os
|
import os
|
||||||
import inspect
|
|
||||||
import redis, litellm # type: ignore
|
|
||||||
import redis.asyncio as async_redis # type: ignore
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import redis # type: ignore
|
||||||
|
import redis.asyncio as async_redis # type: ignore
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
|
||||||
def _get_redis_kwargs():
|
def _get_redis_kwargs():
|
||||||
arg_spec = inspect.getfullargspec(redis.Redis)
|
arg_spec = inspect.getfullargspec(redis.Redis)
|
||||||
|
@ -51,6 +55,19 @@ def _get_redis_url_kwargs(client=None):
|
||||||
return available_args
|
return available_args
|
||||||
|
|
||||||
|
|
||||||
|
def _get_redis_cluster_kwargs(client=None):
|
||||||
|
if client is None:
|
||||||
|
client = redis.Redis.from_url
|
||||||
|
arg_spec = inspect.getfullargspec(redis.RedisCluster)
|
||||||
|
|
||||||
|
# Only allow primitive arguments
|
||||||
|
exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}
|
||||||
|
|
||||||
|
available_args = [x for x in arg_spec.args if x not in exclude_args]
|
||||||
|
|
||||||
|
return available_args
|
||||||
|
|
||||||
|
|
||||||
def _get_redis_env_kwarg_mapping():
|
def _get_redis_env_kwarg_mapping():
|
||||||
PREFIX = "REDIS_"
|
PREFIX = "REDIS_"
|
||||||
|
|
||||||
|
@ -124,6 +141,22 @@ def get_redis_client(**env_overrides):
|
||||||
url_kwargs[arg] = redis_kwargs[arg]
|
url_kwargs[arg] = redis_kwargs[arg]
|
||||||
|
|
||||||
return redis.Redis.from_url(**url_kwargs)
|
return redis.Redis.from_url(**url_kwargs)
|
||||||
|
|
||||||
|
if "startup_nodes" in redis_kwargs:
|
||||||
|
from redis.cluster import ClusterNode
|
||||||
|
|
||||||
|
args = _get_redis_cluster_kwargs()
|
||||||
|
cluster_kwargs = {}
|
||||||
|
for arg in redis_kwargs:
|
||||||
|
if arg in args:
|
||||||
|
cluster_kwargs[arg] = redis_kwargs[arg]
|
||||||
|
|
||||||
|
new_startup_nodes: List[ClusterNode] = []
|
||||||
|
|
||||||
|
for item in redis_kwargs["startup_nodes"]:
|
||||||
|
new_startup_nodes.append(ClusterNode(**item))
|
||||||
|
redis_kwargs.pop("startup_nodes")
|
||||||
|
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
|
||||||
return redis.Redis(**redis_kwargs)
|
return redis.Redis(**redis_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -143,6 +176,24 @@ def get_redis_async_client(**env_overrides):
|
||||||
)
|
)
|
||||||
return async_redis.Redis.from_url(**url_kwargs)
|
return async_redis.Redis.from_url(**url_kwargs)
|
||||||
|
|
||||||
|
if "startup_nodes" in redis_kwargs:
|
||||||
|
from redis.cluster import ClusterNode
|
||||||
|
|
||||||
|
args = _get_redis_cluster_kwargs()
|
||||||
|
cluster_kwargs = {}
|
||||||
|
for arg in redis_kwargs:
|
||||||
|
if arg in args:
|
||||||
|
cluster_kwargs[arg] = redis_kwargs[arg]
|
||||||
|
|
||||||
|
new_startup_nodes: List[ClusterNode] = []
|
||||||
|
|
||||||
|
for item in redis_kwargs["startup_nodes"]:
|
||||||
|
new_startup_nodes.append(ClusterNode(**item))
|
||||||
|
redis_kwargs.pop("startup_nodes")
|
||||||
|
return async_redis.RedisCluster(
|
||||||
|
startup_nodes=new_startup_nodes, **cluster_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
return async_redis.Redis(
|
return async_redis.Redis(
|
||||||
socket_timeout=5,
|
socket_timeout=5,
|
||||||
**redis_kwargs,
|
**redis_kwargs,
|
||||||
|
@ -160,4 +211,5 @@ def get_redis_connection_pool(**env_overrides):
|
||||||
connection_class = async_redis.SSLConnection
|
connection_class = async_redis.SSLConnection
|
||||||
redis_kwargs.pop("ssl", None)
|
redis_kwargs.pop("ssl", None)
|
||||||
redis_kwargs["connection_class"] = connection_class
|
redis_kwargs["connection_class"] = connection_class
|
||||||
|
redis_kwargs.pop("startup_nodes", None)
|
||||||
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
|
return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
|
||||||
|
|
|
@ -203,6 +203,7 @@ class RedisCache(BaseCache):
|
||||||
password=None,
|
password=None,
|
||||||
redis_flush_size=100,
|
redis_flush_size=100,
|
||||||
namespace: Optional[str] = None,
|
namespace: Optional[str] = None,
|
||||||
|
startup_nodes: Optional[List] = None, # for redis-cluster
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
import redis
|
import redis
|
||||||
|
@ -218,7 +219,8 @@ class RedisCache(BaseCache):
|
||||||
redis_kwargs["port"] = port
|
redis_kwargs["port"] = port
|
||||||
if password is not None:
|
if password is not None:
|
||||||
redis_kwargs["password"] = password
|
redis_kwargs["password"] = password
|
||||||
|
if startup_nodes is not None:
|
||||||
|
redis_kwargs["startup_nodes"] = startup_nodes
|
||||||
### HEALTH MONITORING OBJECT ###
|
### HEALTH MONITORING OBJECT ###
|
||||||
if kwargs.get("service_logger_obj", None) is not None and isinstance(
|
if kwargs.get("service_logger_obj", None) is not None and isinstance(
|
||||||
kwargs["service_logger_obj"], ServiceLogging
|
kwargs["service_logger_obj"], ServiceLogging
|
||||||
|
@ -246,7 +248,7 @@ class RedisCache(BaseCache):
|
||||||
### ASYNC HEALTH PING ###
|
### ASYNC HEALTH PING ###
|
||||||
try:
|
try:
|
||||||
# asyncio.get_running_loop().create_task(self.ping())
|
# asyncio.get_running_loop().create_task(self.ping())
|
||||||
result = asyncio.get_running_loop().create_task(self.ping())
|
_ = asyncio.get_running_loop().create_task(self.ping())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if "no running event loop" in str(e):
|
if "no running event loop" in str(e):
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
|
@ -2123,6 +2125,7 @@ class Cache:
|
||||||
redis_semantic_cache_use_async=False,
|
redis_semantic_cache_use_async=False,
|
||||||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||||
redis_flush_size=None,
|
redis_flush_size=None,
|
||||||
|
redis_startup_nodes: Optional[List] = None,
|
||||||
disk_cache_dir=None,
|
disk_cache_dir=None,
|
||||||
qdrant_api_base: Optional[str] = None,
|
qdrant_api_base: Optional[str] = None,
|
||||||
qdrant_api_key: Optional[str] = None,
|
qdrant_api_key: Optional[str] = None,
|
||||||
|
@ -2155,7 +2158,12 @@ class Cache:
|
||||||
"""
|
"""
|
||||||
if type == "redis":
|
if type == "redis":
|
||||||
self.cache: BaseCache = RedisCache(
|
self.cache: BaseCache = RedisCache(
|
||||||
host, port, password, redis_flush_size, **kwargs
|
host,
|
||||||
|
port,
|
||||||
|
password,
|
||||||
|
redis_flush_size,
|
||||||
|
startup_nodes=redis_startup_nodes,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif type == "redis-semantic":
|
elif type == "redis-semantic":
|
||||||
self.cache = RedisSemanticCache(
|
self.cache = RedisSemanticCache(
|
||||||
|
|
|
@ -60,6 +60,25 @@ class PrometheusLogger(CustomLogger):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# request latency metrics
|
||||||
|
self.litellm_request_total_latency_metric = Histogram(
|
||||||
|
"litellm_request_total_latency_metric",
|
||||||
|
"Total latency (seconds) for a request to LiteLLM",
|
||||||
|
labelnames=[
|
||||||
|
"model",
|
||||||
|
"litellm_call_id",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.litellm_llm_api_latency_metric = Histogram(
|
||||||
|
"litellm_llm_api_latency_metric",
|
||||||
|
"Total latency (seconds) for a models LLM API call",
|
||||||
|
labelnames=[
|
||||||
|
"model",
|
||||||
|
"litellm_call_id",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Counter for spend
|
# Counter for spend
|
||||||
self.litellm_spend_metric = Counter(
|
self.litellm_spend_metric = Counter(
|
||||||
"litellm_spend_metric",
|
"litellm_spend_metric",
|
||||||
|
@ -103,8 +122,6 @@ class PrometheusLogger(CustomLogger):
|
||||||
"Remaining budget for api key",
|
"Remaining budget for api key",
|
||||||
labelnames=["hashed_api_key", "api_key_alias"],
|
labelnames=["hashed_api_key", "api_key_alias"],
|
||||||
)
|
)
|
||||||
# Litellm-Enterprise Metrics
|
|
||||||
if premium_user is True:
|
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
# LiteLLM Virtual API KEY metrics
|
# LiteLLM Virtual API KEY metrics
|
||||||
|
@ -123,6 +140,9 @@ class PrometheusLogger(CustomLogger):
|
||||||
labelnames=["hashed_api_key", "api_key_alias", "model"],
|
labelnames=["hashed_api_key", "api_key_alias", "model"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Litellm-Enterprise Metrics
|
||||||
|
if premium_user is True:
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
# LLM API Deployment Metrics / analytics
|
# LLM API Deployment Metrics / analytics
|
||||||
########################################
|
########################################
|
||||||
|
@ -328,6 +348,25 @@ class PrometheusLogger(CustomLogger):
|
||||||
user_api_key, user_api_key_alias, model_group
|
user_api_key, user_api_key_alias, model_group
|
||||||
).set(remaining_tokens)
|
).set(remaining_tokens)
|
||||||
|
|
||||||
|
# latency metrics
|
||||||
|
total_time: timedelta = kwargs.get("end_time") - kwargs.get("start_time")
|
||||||
|
total_time_seconds = total_time.total_seconds()
|
||||||
|
api_call_total_time: timedelta = kwargs.get("end_time") - kwargs.get(
|
||||||
|
"api_call_start_time"
|
||||||
|
)
|
||||||
|
|
||||||
|
api_call_total_time_seconds = api_call_total_time.total_seconds()
|
||||||
|
|
||||||
|
litellm_call_id = kwargs.get("litellm_call_id")
|
||||||
|
|
||||||
|
self.litellm_request_total_latency_metric.labels(
|
||||||
|
model, litellm_call_id
|
||||||
|
).observe(total_time_seconds)
|
||||||
|
|
||||||
|
self.litellm_llm_api_latency_metric.labels(model, litellm_call_id).observe(
|
||||||
|
api_call_total_time_seconds
|
||||||
|
)
|
||||||
|
|
||||||
# set x-ratelimit headers
|
# set x-ratelimit headers
|
||||||
if premium_user is True:
|
if premium_user is True:
|
||||||
self.set_llm_deployment_success_metrics(
|
self.set_llm_deployment_success_metrics(
|
||||||
|
|
|
@ -354,6 +354,8 @@ class Logging:
|
||||||
str(e)
|
str(e)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.model_call_details["api_call_start_time"] = datetime.datetime.now()
|
||||||
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
|
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
|
||||||
callbacks = litellm.input_callback + self.dynamic_input_callbacks
|
callbacks = litellm.input_callback + self.dynamic_input_callbacks
|
||||||
for callback in callbacks:
|
for callback in callbacks:
|
||||||
|
|
|
@ -84,6 +84,7 @@ BEDROCK_CONVERSE_MODELS = [
|
||||||
"meta.llama3-1-8b-instruct-v1:0",
|
"meta.llama3-1-8b-instruct-v1:0",
|
||||||
"meta.llama3-1-70b-instruct-v1:0",
|
"meta.llama3-1-70b-instruct-v1:0",
|
||||||
"meta.llama3-1-405b-instruct-v1:0",
|
"meta.llama3-1-405b-instruct-v1:0",
|
||||||
|
"meta.llama3-70b-instruct-v1:0",
|
||||||
"mistral.mistral-large-2407-v1:0",
|
"mistral.mistral-large-2407-v1:0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -1480,7 +1481,7 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
acompletion: bool,
|
acompletion: bool,
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
litellm_params=None,
|
litellm_params: dict,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
|
@ -1596,6 +1597,14 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
supported_tool_call_params = ["tools", "tool_choice"]
|
supported_tool_call_params = ["tools", "tool_choice"]
|
||||||
supported_guardrail_params = ["guardrailConfig"]
|
supported_guardrail_params = ["guardrailConfig"]
|
||||||
## TRANSFORMATION ##
|
## TRANSFORMATION ##
|
||||||
|
|
||||||
|
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock_converse",
|
||||||
|
user_continue_message=litellm_params.pop("user_continue_message", None),
|
||||||
|
)
|
||||||
|
|
||||||
# send all model-specific params in 'additional_request_params'
|
# send all model-specific params in 'additional_request_params'
|
||||||
for k, v in inference_params.items():
|
for k, v in inference_params.items():
|
||||||
if (
|
if (
|
||||||
|
@ -1608,11 +1617,6 @@ class BedrockConverseLLM(BaseAWSLLM):
|
||||||
for key in additional_request_keys:
|
for key in additional_request_keys:
|
||||||
inference_params.pop(key, None)
|
inference_params.pop(key, None)
|
||||||
|
|
||||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
|
||||||
messages=messages,
|
|
||||||
model=model,
|
|
||||||
llm_provider="bedrock_converse",
|
|
||||||
)
|
|
||||||
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
bedrock_tools: List[ToolBlock] = _bedrock_tools_pt(
|
||||||
inference_params.pop("tools", [])
|
inference_params.pop("tools", [])
|
||||||
)
|
)
|
||||||
|
|
|
@ -124,12 +124,14 @@ class CohereConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_environment(api_key):
|
def validate_environment(api_key, headers: dict):
|
||||||
headers = {
|
headers.update(
|
||||||
|
{
|
||||||
"Request-Source": "unspecified:litellm",
|
"Request-Source": "unspecified:litellm",
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
return headers
|
return headers
|
||||||
|
@ -144,11 +146,12 @@ def completion(
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
|
headers: dict,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
completion_url = api_base
|
completion_url = api_base
|
||||||
model = model
|
model = model
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
prompt = " ".join(message["content"] for message in messages)
|
||||||
|
@ -338,13 +341,14 @@ def embedding(
|
||||||
model_response: litellm.EmbeddingResponse,
|
model_response: litellm.EmbeddingResponse,
|
||||||
logging_obj: LiteLLMLoggingObj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
headers: dict,
|
||||||
encoding: Any,
|
encoding: Any,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
aembedding: Optional[bool] = None,
|
aembedding: Optional[bool] = None,
|
||||||
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
|
||||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
embed_url = "https://api.cohere.ai/v1/embed"
|
embed_url = "https://api.cohere.ai/v1/embed"
|
||||||
model = model
|
model = model
|
||||||
data = {"model": model, "texts": input, **optional_params}
|
data = {"model": model, "texts": input, **optional_params}
|
||||||
|
|
|
@ -116,12 +116,14 @@ class CohereChatConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def validate_environment(api_key):
|
def validate_environment(api_key, headers: dict):
|
||||||
headers = {
|
headers.update(
|
||||||
|
{
|
||||||
"Request-Source": "unspecified:litellm",
|
"Request-Source": "unspecified:litellm",
|
||||||
"accept": "application/json",
|
"accept": "application/json",
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
}
|
}
|
||||||
|
)
|
||||||
if api_key:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
return headers
|
return headers
|
||||||
|
@ -203,13 +205,14 @@ def completion(
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
print_verbose: Callable,
|
print_verbose: Callable,
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
|
headers: dict,
|
||||||
encoding,
|
encoding,
|
||||||
api_key,
|
api_key,
|
||||||
logging_obj,
|
logging_obj,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
):
|
):
|
||||||
headers = validate_environment(api_key)
|
headers = validate_environment(api_key, headers=headers)
|
||||||
completion_url = api_base
|
completion_url = api_base
|
||||||
model = model
|
model = model
|
||||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||||
|
|
|
@ -38,6 +38,18 @@ def prompt_injection_detection_default_pt():
|
||||||
|
|
||||||
BAD_MESSAGE_ERROR_STR = "Invalid Message "
|
BAD_MESSAGE_ERROR_STR = "Invalid Message "
|
||||||
|
|
||||||
|
# used to interweave user messages, to ensure user/assistant alternating
|
||||||
|
DEFAULT_USER_CONTINUE_MESSAGE = {
|
||||||
|
"role": "user",
|
||||||
|
"content": "Please continue.",
|
||||||
|
} # similar to autogen. Only used if `litellm.modify_params=True`.
|
||||||
|
|
||||||
|
# used to interweave assistant messages, to ensure user/assistant alternating
|
||||||
|
DEFAULT_ASSISTANT_CONTINUE_MESSAGE = {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Please continue.",
|
||||||
|
} # similar to autogen. Only used if `litellm.modify_params=True`.
|
||||||
|
|
||||||
|
|
||||||
def map_system_message_pt(messages: list) -> list:
|
def map_system_message_pt(messages: list) -> list:
|
||||||
"""
|
"""
|
||||||
|
@ -2254,6 +2266,7 @@ def _bedrock_converse_messages_pt(
|
||||||
messages: List,
|
messages: List,
|
||||||
model: str,
|
model: str,
|
||||||
llm_provider: str,
|
llm_provider: str,
|
||||||
|
user_continue_message: Optional[dict] = None,
|
||||||
) -> List[BedrockMessageBlock]:
|
) -> List[BedrockMessageBlock]:
|
||||||
"""
|
"""
|
||||||
Converts given messages from OpenAI format to Bedrock format
|
Converts given messages from OpenAI format to Bedrock format
|
||||||
|
@ -2264,6 +2277,21 @@ def _bedrock_converse_messages_pt(
|
||||||
|
|
||||||
contents: List[BedrockMessageBlock] = []
|
contents: List[BedrockMessageBlock] = []
|
||||||
msg_i = 0
|
msg_i = 0
|
||||||
|
|
||||||
|
# if initial message is assistant message
|
||||||
|
if messages[0].get("role") is not None and messages[0]["role"] == "assistant":
|
||||||
|
if user_continue_message is not None:
|
||||||
|
messages.insert(0, user_continue_message)
|
||||||
|
elif litellm.modify_params:
|
||||||
|
messages.insert(0, DEFAULT_USER_CONTINUE_MESSAGE)
|
||||||
|
|
||||||
|
# if final message is assistant message
|
||||||
|
if messages[-1].get("role") is not None and messages[-1]["role"] == "assistant":
|
||||||
|
if user_continue_message is not None:
|
||||||
|
messages.append(user_continue_message)
|
||||||
|
elif litellm.modify_params:
|
||||||
|
messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
|
||||||
|
|
||||||
while msg_i < len(messages):
|
while msg_i < len(messages):
|
||||||
user_content: List[BedrockContentBlock] = []
|
user_content: List[BedrockContentBlock] = []
|
||||||
init_msg_i = msg_i
|
init_msg_i = msg_i
|
||||||
|
@ -2344,6 +2372,7 @@ def _bedrock_converse_messages_pt(
|
||||||
model=model,
|
model=model,
|
||||||
llm_provider=llm_provider,
|
llm_provider=llm_provider,
|
||||||
)
|
)
|
||||||
|
|
||||||
return contents
|
return contents
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -944,6 +944,8 @@ def completion(
|
||||||
cooldown_time=cooldown_time,
|
cooldown_time=cooldown_time,
|
||||||
text_completion=kwargs.get("text_completion"),
|
text_completion=kwargs.get("text_completion"),
|
||||||
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
|
azure_ad_token_provider=kwargs.get("azure_ad_token_provider"),
|
||||||
|
user_continue_message=kwargs.get("user_continue_message"),
|
||||||
|
|
||||||
)
|
)
|
||||||
logging.update_environment_variables(
|
logging.update_environment_variables(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -1635,6 +1637,13 @@ def completion(
|
||||||
or "https://api.cohere.ai/v1/generate"
|
or "https://api.cohere.ai/v1/generate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
headers = headers or litellm.headers or {}
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers.update(extra_headers)
|
||||||
|
|
||||||
model_response = cohere.completion(
|
model_response = cohere.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -1645,6 +1654,7 @@ def completion(
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
|
headers=headers,
|
||||||
api_key=cohere_key,
|
api_key=cohere_key,
|
||||||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||||
)
|
)
|
||||||
|
@ -1675,6 +1685,13 @@ def completion(
|
||||||
or "https://api.cohere.ai/v1/chat"
|
or "https://api.cohere.ai/v1/chat"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
headers = headers or litellm.headers or {}
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers.update(extra_headers)
|
||||||
|
|
||||||
model_response = cohere_chat.completion(
|
model_response = cohere_chat.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -1683,6 +1700,7 @@ def completion(
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params,
|
||||||
|
headers=headers,
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
api_key=cohere_key,
|
api_key=cohere_key,
|
||||||
|
@ -2289,7 +2307,7 @@ def completion(
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
litellm_params=litellm_params,
|
litellm_params=litellm_params, # type: ignore
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
|
@ -3159,6 +3177,7 @@ def embedding(
|
||||||
encoding_format = kwargs.get("encoding_format", None)
|
encoding_format = kwargs.get("encoding_format", None)
|
||||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||||
aembedding = kwargs.get("aembedding", None)
|
aembedding = kwargs.get("aembedding", None)
|
||||||
|
extra_headers = kwargs.get("extra_headers", None)
|
||||||
### CUSTOM MODEL COST ###
|
### CUSTOM MODEL COST ###
|
||||||
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
input_cost_per_token = kwargs.get("input_cost_per_token", None)
|
||||||
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
output_cost_per_token = kwargs.get("output_cost_per_token", None)
|
||||||
|
@ -3234,6 +3253,7 @@ def embedding(
|
||||||
"tenant_id",
|
"tenant_id",
|
||||||
"client_id",
|
"client_id",
|
||||||
"client_secret",
|
"client_secret",
|
||||||
|
"extra_headers",
|
||||||
]
|
]
|
||||||
default_params = openai_params + litellm_params
|
default_params = openai_params + litellm_params
|
||||||
non_default_params = {
|
non_default_params = {
|
||||||
|
@ -3297,7 +3317,7 @@ def embedding(
|
||||||
"cooldown_time": cooldown_time,
|
"cooldown_time": cooldown_time,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
if azure == True or custom_llm_provider == "azure":
|
if azure is True or custom_llm_provider == "azure":
|
||||||
# azure configs
|
# azure configs
|
||||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||||
|
|
||||||
|
@ -3403,12 +3423,18 @@ def embedding(
|
||||||
or get_secret("CO_API_KEY")
|
or get_secret("CO_API_KEY")
|
||||||
or litellm.api_key
|
or litellm.api_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if extra_headers is not None and isinstance(extra_headers, dict):
|
||||||
|
headers = extra_headers
|
||||||
|
else:
|
||||||
|
headers = {}
|
||||||
response = cohere.embedding(
|
response = cohere.embedding(
|
||||||
model=model,
|
model=model,
|
||||||
input=input,
|
input=input,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
api_key=cohere_key, # type: ignore
|
api_key=cohere_key, # type: ignore
|
||||||
|
headers=headers,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
model_response=EmbeddingResponse(),
|
model_response=EmbeddingResponse(),
|
||||||
aembedding=aembedding,
|
aembedding=aembedding,
|
||||||
|
|
|
@ -2,12 +2,3 @@ model_list:
|
||||||
- model_name: "*"
|
- model_name: "*"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: "*"
|
model: "*"
|
||||||
|
|
||||||
litellm_settings:
|
|
||||||
success_callback: ["s3"]
|
|
||||||
cache: true
|
|
||||||
s3_callback_params:
|
|
||||||
s3_bucket_name: mytestbucketlitellm # AWS Bucket Name for S3
|
|
||||||
s3_region_name: us-west-2 # AWS Region Name for S3
|
|
||||||
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
|
|
||||||
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ def common_checks(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
f"Team={team_object.team_id} is blocked. Update via `/team/unblock` if your admin."
|
||||||
)
|
)
|
||||||
# 2. If user can call model
|
# 2. If team can call model
|
||||||
if (
|
if (
|
||||||
_model is not None
|
_model is not None
|
||||||
and team_object is not None
|
and team_object is not None
|
||||||
|
@ -74,7 +74,11 @@ def common_checks(
|
||||||
and _model not in team_object.models
|
and _model not in team_object.models
|
||||||
):
|
):
|
||||||
# this means the team has access to all models on the proxy
|
# this means the team has access to all models on the proxy
|
||||||
if "all-proxy-models" in team_object.models:
|
if (
|
||||||
|
"all-proxy-models" in team_object.models
|
||||||
|
or "*" in team_object.models
|
||||||
|
or "openai/*" in team_object.models
|
||||||
|
):
|
||||||
# this means the team has access to all models on the proxy
|
# this means the team has access to all models on the proxy
|
||||||
pass
|
pass
|
||||||
# check if the team model is an access_group
|
# check if the team model is an access_group
|
||||||
|
|
|
@ -22,3 +22,9 @@ guardrails:
|
||||||
mode: "post_call"
|
mode: "post_call"
|
||||||
api_key: os.environ/APORIA_API_KEY_2
|
api_key: os.environ/APORIA_API_KEY_2
|
||||||
api_base: os.environ/APORIA_API_BASE_2
|
api_base: os.environ/APORIA_API_BASE_2
|
||||||
|
- guardrail_name: "bedrock-pre-guard"
|
||||||
|
litellm_params:
|
||||||
|
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
|
||||||
|
mode: "pre_call"
|
||||||
|
guardrailIdentifier: ff6ujrregl1q
|
||||||
|
guardrailVersion: "DRAFT"
|
289
litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
Normal file
289
litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
Normal file
|
@ -0,0 +1,289 @@
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
#
|
||||||
|
# Use Bedrock Guardrails for your LLM calls
|
||||||
|
#
|
||||||
|
# +-------------------------------------------------------------+
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import httpx
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import get_secret
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||||
|
from litellm.litellm_core_utils.logging_utils import (
|
||||||
|
convert_litellm_response_object_to_str,
|
||||||
|
)
|
||||||
|
from litellm.llms.base_aws_llm import BaseAWSLLM
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
_get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.guardrails.guardrail_helpers import should_proceed_based_on_metadata
|
||||||
|
from litellm.types.guardrails import (
|
||||||
|
BedrockContentItem,
|
||||||
|
BedrockRequest,
|
||||||
|
BedrockTextContent,
|
||||||
|
GuardrailEventHooks,
|
||||||
|
)
|
||||||
|
|
||||||
|
GUARDRAIL_NAME = "bedrock"
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
guardrailIdentifier: Optional[str] = None,
|
||||||
|
guardrailVersion: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.async_handler = _get_async_httpx_client()
|
||||||
|
self.guardrailIdentifier = guardrailIdentifier
|
||||||
|
self.guardrailVersion = guardrailVersion
|
||||||
|
|
||||||
|
# store kwargs as optional_params
|
||||||
|
self.optional_params = kwargs
|
||||||
|
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def convert_to_bedrock_format(
|
||||||
|
self,
|
||||||
|
messages: Optional[List[Dict[str, str]]] = None,
|
||||||
|
response: Optional[Union[Any, litellm.ModelResponse]] = None,
|
||||||
|
) -> BedrockRequest:
|
||||||
|
bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
|
||||||
|
bedrock_request_content: List[BedrockContentItem] = []
|
||||||
|
|
||||||
|
if messages:
|
||||||
|
for message in messages:
|
||||||
|
content = message.get("content")
|
||||||
|
if isinstance(content, str):
|
||||||
|
bedrock_content_item = BedrockContentItem(
|
||||||
|
text=BedrockTextContent(text=content)
|
||||||
|
)
|
||||||
|
bedrock_request_content.append(bedrock_content_item)
|
||||||
|
|
||||||
|
bedrock_request["content"] = bedrock_request_content
|
||||||
|
if response:
|
||||||
|
bedrock_request["source"] = "OUTPUT"
|
||||||
|
if isinstance(response, litellm.ModelResponse):
|
||||||
|
for choice in response.choices:
|
||||||
|
if isinstance(choice, litellm.Choices):
|
||||||
|
if choice.message.content and isinstance(
|
||||||
|
choice.message.content, str
|
||||||
|
):
|
||||||
|
bedrock_content_item = BedrockContentItem(
|
||||||
|
text=BedrockTextContent(text=choice.message.content)
|
||||||
|
)
|
||||||
|
bedrock_request_content.append(bedrock_content_item)
|
||||||
|
bedrock_request["content"] = bedrock_request_content
|
||||||
|
return bedrock_request
|
||||||
|
|
||||||
|
#### CALL HOOKS - proxy only ####
|
||||||
|
def _load_credentials(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
## CREDENTIALS ##
|
||||||
|
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||||
|
aws_secret_access_key = self.optional_params.pop("aws_secret_access_key", None)
|
||||||
|
aws_access_key_id = self.optional_params.pop("aws_access_key_id", None)
|
||||||
|
aws_session_token = self.optional_params.pop("aws_session_token", None)
|
||||||
|
aws_region_name = self.optional_params.pop("aws_region_name", None)
|
||||||
|
aws_role_name = self.optional_params.pop("aws_role_name", None)
|
||||||
|
aws_session_name = self.optional_params.pop("aws_session_name", None)
|
||||||
|
aws_profile_name = self.optional_params.pop("aws_profile_name", None)
|
||||||
|
aws_bedrock_runtime_endpoint = self.optional_params.pop(
|
||||||
|
"aws_bedrock_runtime_endpoint", None
|
||||||
|
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||||
|
aws_web_identity_token = self.optional_params.pop(
|
||||||
|
"aws_web_identity_token", None
|
||||||
|
)
|
||||||
|
aws_sts_endpoint = self.optional_params.pop("aws_sts_endpoint", None)
|
||||||
|
|
||||||
|
### SET REGION NAME ###
|
||||||
|
if aws_region_name is None:
|
||||||
|
# check env #
|
||||||
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
|
|
||||||
|
if litellm_aws_region_name is not None and isinstance(
|
||||||
|
litellm_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = litellm_aws_region_name
|
||||||
|
|
||||||
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
if standard_aws_region_name is not None and isinstance(
|
||||||
|
standard_aws_region_name, str
|
||||||
|
):
|
||||||
|
aws_region_name = standard_aws_region_name
|
||||||
|
|
||||||
|
if aws_region_name is None:
|
||||||
|
aws_region_name = "us-west-2"
|
||||||
|
|
||||||
|
credentials: Credentials = self.get_credentials(
|
||||||
|
aws_access_key_id=aws_access_key_id,
|
||||||
|
aws_secret_access_key=aws_secret_access_key,
|
||||||
|
aws_session_token=aws_session_token,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
aws_session_name=aws_session_name,
|
||||||
|
aws_profile_name=aws_profile_name,
|
||||||
|
aws_role_name=aws_role_name,
|
||||||
|
aws_web_identity_token=aws_web_identity_token,
|
||||||
|
aws_sts_endpoint=aws_sts_endpoint,
|
||||||
|
)
|
||||||
|
return credentials, aws_region_name
|
||||||
|
|
||||||
|
def _prepare_request(
|
||||||
|
self,
|
||||||
|
credentials,
|
||||||
|
data: BedrockRequest,
|
||||||
|
optional_params: dict,
|
||||||
|
aws_region_name: str,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
from botocore.auth import SigV4Auth
|
||||||
|
from botocore.awsrequest import AWSRequest
|
||||||
|
from botocore.credentials import Credentials
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||||
|
|
||||||
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
|
||||||
|
api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
|
||||||
|
|
||||||
|
encoded_data = json.dumps(data).encode("utf-8")
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if extra_headers is not None:
|
||||||
|
headers = {"Content-Type": "application/json", **extra_headers}
|
||||||
|
|
||||||
|
request = AWSRequest(
|
||||||
|
method="POST", url=api_base, data=encoded_data, headers=headers
|
||||||
|
)
|
||||||
|
sigv4.add_auth(request)
|
||||||
|
prepped_request = request.prepare()
|
||||||
|
|
||||||
|
return prepped_request
|
||||||
|
|
||||||
|
async def make_bedrock_api_request(
|
||||||
|
self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
|
||||||
|
):
|
||||||
|
|
||||||
|
credentials, aws_region_name = self._load_credentials()
|
||||||
|
request_data: BedrockRequest = self.convert_to_bedrock_format(
|
||||||
|
messages=kwargs.get("messages"), response=response
|
||||||
|
)
|
||||||
|
prepared_request = self._prepare_request(
|
||||||
|
credentials=credentials,
|
||||||
|
data=request_data,
|
||||||
|
optional_params=self.optional_params,
|
||||||
|
aws_region_name=aws_region_name,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
"Bedrock AI request body: %s, url %s, headers: %s",
|
||||||
|
request_data,
|
||||||
|
prepared_request.url,
|
||||||
|
prepared_request.headers,
|
||||||
|
)
|
||||||
|
_json_data = json.dumps(request_data) # type: ignore
|
||||||
|
response = await self.async_handler.post(
|
||||||
|
url=prepared_request.url,
|
||||||
|
json=request_data, # type: ignore
|
||||||
|
headers=prepared_request.headers,
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
|
||||||
|
if response.status_code == 200:
|
||||||
|
# check if the response was flagged
|
||||||
|
_json_response = response.json()
|
||||||
|
if _json_response.get("action") == "GUARDRAIL_INTERVENED":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "Violated guardrail policy",
|
||||||
|
"bedrock_guardrail_response": _json_response,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.error(
|
||||||
|
"Bedrock AI: error in response. Status code: %s, response: %s",
|
||||||
|
response.status_code,
|
||||||
|
response.text,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_moderation_hook( ### 👈 KEY CHANGE ###
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
call_type: Literal["completion", "embeddings", "image_generation"],
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||||
|
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_messages: Optional[List[dict]] = data.get("messages")
|
||||||
|
if new_messages is not None:
|
||||||
|
await self.make_bedrock_api_request(kwargs=data)
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Bedrock AI: not running guardrail. No messages in data"
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_post_call_success_hook(
|
||||||
|
self,
|
||||||
|
data: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
response,
|
||||||
|
):
|
||||||
|
from litellm.proxy.common_utils.callback_utils import (
|
||||||
|
add_guardrail_to_applied_guardrails_header,
|
||||||
|
)
|
||||||
|
from litellm.types.guardrails import GuardrailEventHooks
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.should_run_guardrail(
|
||||||
|
data=data, event_type=GuardrailEventHooks.post_call
|
||||||
|
)
|
||||||
|
is not True
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
new_messages: Optional[List[dict]] = data.get("messages")
|
||||||
|
if new_messages is not None:
|
||||||
|
await self.make_bedrock_api_request(kwargs=data, response=response)
|
||||||
|
add_guardrail_to_applied_guardrails_header(
|
||||||
|
request_data=data, guardrail_name=self.guardrail_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Bedrock AI: not running guardrail. No messages in data"
|
||||||
|
)
|
|
@ -96,8 +96,10 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
litellm_params = LitellmParams(
|
litellm_params = LitellmParams(
|
||||||
guardrail=litellm_params_data["guardrail"],
|
guardrail=litellm_params_data["guardrail"],
|
||||||
mode=litellm_params_data["mode"],
|
mode=litellm_params_data["mode"],
|
||||||
api_key=litellm_params_data["api_key"],
|
api_key=litellm_params_data.get("api_key"),
|
||||||
api_base=litellm_params_data["api_base"],
|
api_base=litellm_params_data.get("api_base"),
|
||||||
|
guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
|
||||||
|
guardrailVersion=litellm_params_data.get("guardrailVersion"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -134,6 +136,18 @@ def init_guardrails_v2(all_guardrails: dict):
|
||||||
event_hook=litellm_params["mode"],
|
event_hook=litellm_params["mode"],
|
||||||
)
|
)
|
||||||
litellm.callbacks.append(_aporia_callback) # type: ignore
|
litellm.callbacks.append(_aporia_callback) # type: ignore
|
||||||
|
if litellm_params["guardrail"] == "bedrock":
|
||||||
|
from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
|
||||||
|
BedrockGuardrail,
|
||||||
|
)
|
||||||
|
|
||||||
|
_bedrock_callback = BedrockGuardrail(
|
||||||
|
guardrail_name=guardrail["guardrail_name"],
|
||||||
|
event_hook=litellm_params["mode"],
|
||||||
|
guardrailIdentifier=litellm_params["guardrailIdentifier"],
|
||||||
|
guardrailVersion=litellm_params["guardrailVersion"],
|
||||||
|
)
|
||||||
|
litellm.callbacks.append(_bedrock_callback) # type: ignore
|
||||||
elif litellm_params["guardrail"] == "lakera":
|
elif litellm_params["guardrail"] == "lakera":
|
||||||
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
|
from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
|
||||||
lakeraAI_Moderation,
|
lakeraAI_Moderation,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
model_list:
|
model_list:
|
||||||
- model_name: gpt-4
|
- model_name: fake-openai-endpoint
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: azure/chatgpt-v-2
|
model: azure/chatgpt-v-2
|
||||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||||
|
@ -9,13 +9,9 @@ model_list:
|
||||||
client_secret: os.environ/AZURE_CLIENT_SECRET
|
client_secret: os.environ/AZURE_CLIENT_SECRET
|
||||||
|
|
||||||
guardrails:
|
guardrails:
|
||||||
- guardrail_name: "lakera-pre-guard"
|
- guardrail_name: "bedrock-pre-guard"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
guardrail: bedrock # supported values: "aporia", "bedrock", "lakera"
|
||||||
mode: "during_call"
|
mode: "post_call"
|
||||||
api_key: os.environ/LAKERA_API_KEY
|
guardrailIdentifier: ff6ujrregl1q
|
||||||
api_base: os.environ/LAKERA_API_BASE
|
guardrailVersion: "DRAFT"
|
||||||
category_thresholds:
|
|
||||||
prompt_injection: 0.1
|
|
||||||
jailbreak: 0.1
|
|
||||||
|
|
|
@ -1588,7 +1588,7 @@ class ProxyConfig:
|
||||||
verbose_proxy_logger.debug( # noqa
|
verbose_proxy_logger.debug( # noqa
|
||||||
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
|
f"{blue_color_code}Set Cache on LiteLLM Proxy: {vars(litellm.cache.cache)}{reset_color_code}"
|
||||||
)
|
)
|
||||||
elif key == "cache" and value == False:
|
elif key == "cache" and value is False:
|
||||||
pass
|
pass
|
||||||
elif key == "guardrails":
|
elif key == "guardrails":
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
|
@ -2672,6 +2672,13 @@ def giveup(e):
|
||||||
and isinstance(e.message, str)
|
and isinstance(e.message, str)
|
||||||
and "Max parallel request limit reached" in e.message
|
and "Max parallel request limit reached" in e.message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
general_settings.get("disable_retry_on_max_parallel_request_limit_error")
|
||||||
|
is True
|
||||||
|
):
|
||||||
|
return True # giveup if queuing max parallel request limits is disabled
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
|
verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -277,7 +277,8 @@ class Router:
|
||||||
"local" # default to an in-memory cache
|
"local" # default to an in-memory cache
|
||||||
)
|
)
|
||||||
redis_cache = None
|
redis_cache = None
|
||||||
cache_config = {}
|
cache_config: Dict[str, Any] = {}
|
||||||
|
|
||||||
self.client_ttl = client_ttl
|
self.client_ttl = client_ttl
|
||||||
if redis_url is not None or (
|
if redis_url is not None or (
|
||||||
redis_host is not None
|
redis_host is not None
|
||||||
|
|
|
@ -738,8 +738,9 @@ def test_bedrock_system_prompt(system, model):
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": system},
|
{"role": "system", "content": system},
|
||||||
{"role": "user", "content": "hey, how's it going?"},
|
{"role": "assistant", "content": "hey, how's it going?"},
|
||||||
],
|
],
|
||||||
|
"user_continue_message": {"role": "user", "content": "Be a good bot!"},
|
||||||
}
|
}
|
||||||
response: ModelResponse = completion(
|
response: ModelResponse = completion(
|
||||||
model="bedrock/{}".format(model),
|
model="bedrock/{}".format(model),
|
||||||
|
|
|
@ -804,6 +804,38 @@ def test_redis_cache_completion_stream():
|
||||||
# test_redis_cache_completion_stream()
|
# test_redis_cache_completion_stream()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Local test. Requires running redis cluster locally.")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redis_cache_cluster_init_unit_test():
|
||||||
|
try:
|
||||||
|
from redis.asyncio import RedisCluster as AsyncRedisCluster
|
||||||
|
from redis.cluster import RedisCluster
|
||||||
|
|
||||||
|
from litellm.caching import RedisCache
|
||||||
|
|
||||||
|
litellm.set_verbose = True
|
||||||
|
|
||||||
|
# List of startup nodes
|
||||||
|
startup_nodes = [
|
||||||
|
{"host": "127.0.0.1", "port": "7001"},
|
||||||
|
]
|
||||||
|
|
||||||
|
resp = RedisCache(startup_nodes=startup_nodes)
|
||||||
|
|
||||||
|
assert isinstance(resp.redis_client, RedisCluster)
|
||||||
|
assert isinstance(resp.init_async_client(), AsyncRedisCluster)
|
||||||
|
|
||||||
|
resp = litellm.Cache(type="redis", redis_startup_nodes=startup_nodes)
|
||||||
|
|
||||||
|
assert isinstance(resp.cache, RedisCache)
|
||||||
|
assert isinstance(resp.cache.redis_client, RedisCluster)
|
||||||
|
assert isinstance(resp.cache.init_async_client(), AsyncRedisCluster)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"{str(e)}\n\n{traceback.format_exc()}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redis_cache_acompletion_stream():
|
async def test_redis_cache_acompletion_stream():
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -3653,6 +3653,7 @@ def test_completion_cohere():
|
||||||
response = completion(
|
response = completion(
|
||||||
model="command-r",
|
model="command-r",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
extra_headers={"Helicone-Property-Locale": "ko"},
|
||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, List, Optional, TypedDict
|
from typing import Dict, List, Literal, Optional, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
from typing_extensions import Required, TypedDict
|
from typing_extensions import Required, TypedDict
|
||||||
|
@ -76,8 +76,14 @@ class LitellmParams(TypedDict, total=False):
|
||||||
mode: str
|
mode: str
|
||||||
api_key: str
|
api_key: str
|
||||||
api_base: Optional[str]
|
api_base: Optional[str]
|
||||||
|
|
||||||
|
# Lakera specific params
|
||||||
category_thresholds: Optional[LakeraCategoryThresholds]
|
category_thresholds: Optional[LakeraCategoryThresholds]
|
||||||
|
|
||||||
|
# Bedrock specific params
|
||||||
|
guardrailIdentifier: Optional[str]
|
||||||
|
guardrailVersion: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class Guardrail(TypedDict):
|
class Guardrail(TypedDict):
|
||||||
guardrail_name: str
|
guardrail_name: str
|
||||||
|
@ -92,3 +98,16 @@ class GuardrailEventHooks(str, Enum):
|
||||||
pre_call = "pre_call"
|
pre_call = "pre_call"
|
||||||
post_call = "post_call"
|
post_call = "post_call"
|
||||||
during_call = "during_call"
|
during_call = "during_call"
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockTextContent(TypedDict, total=False):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockContentItem(TypedDict, total=False):
|
||||||
|
text: BedrockTextContent
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockRequest(TypedDict, total=False):
|
||||||
|
source: Literal["INPUT", "OUTPUT"]
|
||||||
|
content: List[BedrockContentItem]
|
||||||
|
|
|
@ -1120,6 +1120,7 @@ all_litellm_params = [
|
||||||
"tenant_id",
|
"tenant_id",
|
||||||
"client_id",
|
"client_id",
|
||||||
"client_secret",
|
"client_secret",
|
||||||
|
"user_continue_message",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2324,6 +2324,7 @@ def get_litellm_params(
|
||||||
cooldown_time=None,
|
cooldown_time=None,
|
||||||
text_completion=None,
|
text_completion=None,
|
||||||
azure_ad_token_provider=None,
|
azure_ad_token_provider=None,
|
||||||
|
user_continue_message=None,
|
||||||
):
|
):
|
||||||
litellm_params = {
|
litellm_params = {
|
||||||
"acompletion": acompletion,
|
"acompletion": acompletion,
|
||||||
|
@ -2349,6 +2350,7 @@ def get_litellm_params(
|
||||||
"cooldown_time": cooldown_time,
|
"cooldown_time": cooldown_time,
|
||||||
"text_completion": text_completion,
|
"text_completion": text_completion,
|
||||||
"azure_ad_token_provider": azure_ad_token_provider,
|
"azure_ad_token_provider": azure_ad_token_provider,
|
||||||
|
"user_continue_message": user_continue_message,
|
||||||
}
|
}
|
||||||
|
|
||||||
return litellm_params
|
return litellm_params
|
||||||
|
@ -4221,6 +4223,7 @@ def get_supported_openai_params(
|
||||||
"presence_penalty",
|
"presence_penalty",
|
||||||
"stop",
|
"stop",
|
||||||
"n",
|
"n",
|
||||||
|
"extra_headers",
|
||||||
]
|
]
|
||||||
elif custom_llm_provider == "cohere_chat":
|
elif custom_llm_provider == "cohere_chat":
|
||||||
return [
|
return [
|
||||||
|
@ -4235,6 +4238,7 @@ def get_supported_openai_params(
|
||||||
"tools",
|
"tools",
|
||||||
"tool_choice",
|
"tool_choice",
|
||||||
"seed",
|
"seed",
|
||||||
|
"extra_headers",
|
||||||
]
|
]
|
||||||
elif custom_llm_provider == "maritalk":
|
elif custom_llm_provider == "maritalk":
|
||||||
return [
|
return [
|
||||||
|
@ -7123,6 +7127,14 @@ def exception_type(
|
||||||
llm_provider="bedrock",
|
llm_provider="bedrock",
|
||||||
response=original_exception.response,
|
response=original_exception.response,
|
||||||
)
|
)
|
||||||
|
elif "A conversation must start with a user message." in error_str:
|
||||||
|
exception_mapping_worked = True
|
||||||
|
raise BadRequestError(
|
||||||
|
message=f"BedrockException - {error_str}\n. Pass in default user message via `completion(..,user_continue_message=)` or enable `litellm.modify_params=True`.\nFor Proxy: do via `litellm_settings::modify_params: True` or user_continue_message under `litellm_params`",
|
||||||
|
model=model,
|
||||||
|
llm_provider="bedrock",
|
||||||
|
response=original_exception.response,
|
||||||
|
)
|
||||||
elif (
|
elif (
|
||||||
"Unable to locate credentials" in error_str
|
"Unable to locate credentials" in error_str
|
||||||
or "The security token included in the request is invalid"
|
or "The security token included in the request is invalid"
|
||||||
|
|
|
@ -144,6 +144,7 @@ async def test_no_llm_guard_triggered():
|
||||||
|
|
||||||
assert "x-litellm-applied-guardrails" not in headers
|
assert "x-litellm-applied-guardrails" not in headers
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_guardrails_with_api_key_controls():
|
async def test_guardrails_with_api_key_controls():
|
||||||
"""
|
"""
|
||||||
|
@ -194,3 +195,25 @@ async def test_guardrails_with_api_key_controls():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
assert "Aporia detected and blocked PII" in str(e)
|
assert "Aporia detected and blocked PII" in str(e)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bedrock_guardrail_triggered():
|
||||||
|
"""
|
||||||
|
- Tests a request where our bedrock guardrail should be triggered
|
||||||
|
- Assert that the guardrails applied are returned in the response headers
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
response, headers = await chat_completion(
|
||||||
|
session,
|
||||||
|
"sk-1234",
|
||||||
|
model="fake-openai-endpoint",
|
||||||
|
messages=[{"role": "user", "content": f"Hello do you like coffee?"}],
|
||||||
|
guardrails=["bedrock-pre-guard"],
|
||||||
|
)
|
||||||
|
pytest.fail("Should have thrown an exception")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
assert "GUARDRAIL_INTERVENED" in str(e)
|
||||||
|
assert "Violated guardrail policy" in str(e)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue