Merge branch 'main' into litellm_disable_storing_master_key_hash_in_db
Commit 72169fd5c4
22 changed files with 537 additions and 169 deletions
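Besides the master-key change this branch carries, the merged hunks below cover: the Qdrant semantic cache (renaming `qdrant_url` to `qdrant_api_base`, dropping `qdrant_host_type`, and resolving credentials from `os.environ/` secret references), matching docs and proxy config for `qdrant-semantic` caching plus a new `x-litellm-semantic-similarity` response header, scrubbing of sensitive `logging` keys from logged metadata, team-admin access checks on `/team/member_add` and `/team/member_delete`, `response_mime_type` handling for Gemini JSON-schema responses, and a `default_value` fallback in `get_secret`.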
@@ -161,8 +161,7 @@ random_number = random.randint(
 print("testing semantic caching")
 litellm.cache = Cache(
     type="qdrant-semantic",
-    qdrant_host_type="cloud", # can be either 'cloud' or 'local'
-    qdrant_url=os.environ["QDRANT_URL"],
+    qdrant_api_base=os.environ["QDRANT_API_BASE"],
     qdrant_api_key=os.environ["QDRANT_API_KEY"],
     qdrant_collection_name="your_collection_name", # any name of your collection
     similarity_threshold=0.7, # similarity threshold for cache hits, 0 == no similarity, 1 = exact matches, 0.5 == 50% similarity
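For orientation, a minimal sketch of the renamed caller-side interface that the docs hunks above and below describe (assumes a reachable Qdrant cluster with real credentials in `QDRANT_API_BASE` / `QDRANT_API_KEY`; the collection name and threshold are illustrative):

```python
import os

import litellm
from litellm.caching import Cache  # module edited later in this diff

litellm.cache = Cache(
    type="qdrant-semantic",
    qdrant_api_base=os.environ["QDRANT_API_BASE"],
    qdrant_api_key=os.environ["QDRANT_API_KEY"],
    qdrant_collection_name="your_collection_name",
    similarity_threshold=0.7,  # 1.0 would only accept exact semantic matches
)
```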
@@ -491,12 +490,11 @@ def __init__(
     disk_cache_dir=None,

     # qdrant cache params
-    qdrant_url: Optional[str] = None,
+    qdrant_api_base: Optional[str] = None,
     qdrant_api_key: Optional[str] = None,
     qdrant_collection_name: Optional[str] = None,
     qdrant_quantization_config: Optional[str] = None,
     qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
-    qdrant_host_type: Optional[Literal["local","cloud"]] = "local",

     **kwargs
 ):
@@ -7,6 +7,7 @@ Cache LLM Responses
 LiteLLM supports:
 - In Memory Cache
 - Redis Cache
+- Qdrant Semantic Cache
 - Redis Semantic Cache
 - s3 Bucket Cache

@@ -103,6 +104,66 @@ $ litellm --config /path/to/config.yaml
 ```
 </TabItem>

+
+<TabItem value="qdrant-semantic" label="Qdrant Semantic cache">
+
+Caching can be enabled by adding the `cache` key in the `config.yaml`
+
+#### Step 1: Add `cache` to the config.yaml
+```yaml
+model_list:
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: openai-embedding
+    litellm_params:
+      model: openai/text-embedding-3-small
+      api_key: os.environ/OPENAI_API_KEY
+
+litellm_settings:
+  set_verbose: True
+  cache: True          # set cache responses to True, litellm defaults to using a redis cache
+  cache_params:
+    type: qdrant-semantic
+    qdrant_semantic_cache_embedding_model: openai-embedding # the model should be defined on the model_list
+    qdrant_collection_name: test_collection
+    qdrant_quantization_config: binary
+    similarity_threshold: 0.8   # similarity threshold for semantic cache
+```
+
+#### Step 2: Add Qdrant Credentials to your .env
+
+```shell
+QDRANT_API_KEY = "16rJUMBRx*************"
+QDRANT_API_BASE = "https://5392d382-45*********.cloud.qdrant.io"
+```
+
+#### Step 3: Run proxy with config
+```shell
+$ litellm --config /path/to/config.yaml
+```
+
+
+#### Step 4. Test it
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "fake-openai-endpoint",
+    "messages": [
+      {"role": "user", "content": "Hello"}
+    ]
+  }'
+```
+
+**Expect to see `x-litellm-semantic-similarity` in the response headers when semantic caching is on**
+
+</TabItem>
+
 <TabItem value="s3" label="s3 cache">

 #### Step 1: Add `cache` to the config.yaml
@@ -182,6 +243,9 @@ REDIS_<redis-kwarg-name> = ""
 $ litellm --config /path/to/config.yaml
 ```
 </TabItem>
+
+
+
 </Tabs>

@@ -74,6 +74,7 @@ const sidebars = {
        "proxy/alerting",
        "proxy/ui",
        "proxy/prometheus",
+       "proxy/caching",
        "proxy/pass_through",
        "proxy/email",
        "proxy/multiple_admins",
@@ -88,7 +89,6 @@ const sidebars = {
        "proxy/health",
        "proxy/debugging",
        "proxy/pii_masking",
-       "proxy/caching",
        "proxy/call_hooks",
        "proxy/rules",
        "proxy/cli",
@@ -1219,20 +1219,23 @@ class RedisSemanticCache(BaseCache):
     async def _index_info(self):
         return await self.index.ainfo()

+
 class QdrantSemanticCache(BaseCache):
     def __init__(
         self,
-        qdrant_url=None,
+        qdrant_api_base=None,
         qdrant_api_key=None,
         collection_name=None,
         similarity_threshold=None,
         quantization_config=None,
         embedding_model="text-embedding-ada-002",
-        host_type = None
+        host_type=None,
     ):
+        import os
+
         from litellm.llms.custom_httpx.http_handler import (
+            _get_async_httpx_client,
             _get_httpx_client,
-            _get_async_httpx_client
         )

         if collection_name is None:
@@ -1247,45 +1250,32 @@ class QdrantSemanticCache(BaseCache):
             raise Exception("similarity_threshold must be provided, passed None")
         self.similarity_threshold = similarity_threshold
         self.embedding_model = embedding_model
+        headers = {}

-        if host_type=="cloud":
-            import os
-            if qdrant_url is None:
-                qdrant_url = os.getenv('QDRANT_URL')
-            if qdrant_api_key is None:
-                qdrant_api_key = os.getenv('QDRANT_API_KEY')
-            if qdrant_url is not None and qdrant_api_key is not None:
-                headers = {
-                    "api-key": qdrant_api_key,
-                    "Content-Type": "application/json"
-                }
-            else:
-                raise Exception("Qdrant url and api_key must be provided for qdrant cloud hosting")
-        elif host_type=="local":
-            import os
-            if qdrant_url is None:
-                qdrant_url = os.getenv('QDRANT_URL')
-            if qdrant_url is None:
-                raise Exception("Qdrant url must be provided for qdrant local hosting")
-            if qdrant_api_key is None:
-                qdrant_api_key = os.getenv('QDRANT_API_KEY')
-            if qdrant_api_key is None:
-                print_verbose('Running locally without API Key.')
-                headers= {
-                    "Content-Type": "application/json"
-                }
-            else:
-                print_verbose("Running locally with API Key")
-                headers = {
-                    "api-key": qdrant_api_key,
-                    "Content-Type": "application/json"
-                }
-        else:
-            raise Exception("Host type can be either 'local' or 'cloud'")
+        # check if defined as os.environ/ variable
+        if qdrant_api_base:
+            if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
+                "os.environ/"
+            ):
+                qdrant_api_base = litellm.get_secret(qdrant_api_base)
+        if qdrant_api_key:
+            if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
+                "os.environ/"
+            ):
+                qdrant_api_key = litellm.get_secret(qdrant_api_key)

-        self.qdrant_url = qdrant_url
+        qdrant_api_base = (
+            qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
+        )
+        qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
+        headers = {"api-key": qdrant_api_key, "Content-Type": "application/json"}
+
+        if qdrant_api_key is None or qdrant_api_base is None:
+            raise ValueError("Qdrant url and api_key must be")
+
+        self.qdrant_api_base = qdrant_api_base
         self.qdrant_api_key = qdrant_api_key
-        print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}")
+        print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")

         self.headers = headers
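The rewritten constructor above replaces the `host_type` branches with one resolution chain: explicit arguments win, `os.environ/`-prefixed values are dereferenced as secret references, and plain environment variables are the fallback (`QDRANT_URL` is kept for compatibility). A standalone sketch of that chain, with a simple stand-in for `litellm.get_secret` (the real call can also consult a configured secret manager):

```python
import os
from typing import Optional, Tuple


def _get_secret(name: str) -> Optional[str]:
    # stand-in for litellm.get_secret: strip the prefix, read the env var
    return os.getenv(name.replace("os.environ/", ""))


def resolve_qdrant_credentials(
    qdrant_api_base: Optional[str] = None,
    qdrant_api_key: Optional[str] = None,
) -> Tuple[str, str]:
    # dereference "os.environ/..." style values, as the new __init__ does
    if qdrant_api_base and qdrant_api_base.startswith("os.environ/"):
        qdrant_api_base = _get_secret(qdrant_api_base)
    if qdrant_api_key and qdrant_api_key.startswith("os.environ/"):
        qdrant_api_key = _get_secret(qdrant_api_key)

    # fall back to plain environment variables
    qdrant_api_base = (
        qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
    )
    qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")

    if qdrant_api_base is None or qdrant_api_key is None:
        raise ValueError("Qdrant url and api_key must be provided")
    return qdrant_api_base, qdrant_api_key
```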
@@ -1293,62 +1283,64 @@ class QdrantSemanticCache(BaseCache):
         self.async_client = _get_async_httpx_client()

         if quantization_config is None:
-            print('Quantization config is not provided. Default binary quantization will be used.')
-        collection_exists = self.sync_client.get(
-            url= f"{self.qdrant_url}/collections/{self.collection_name}/exists",
-            headers=self.headers
-        )
-        if collection_exists.json()['result']['exists']:
+            print_verbose(
+                "Quantization config is not provided. Default binary quantization will be used."
+            )
+        collection_exists = self.sync_client.get(
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
+            headers=self.headers,
+        )
+        if collection_exists.status_code != 200:
+            raise ValueError(
+                f"Error from qdrant checking if /collections exist {collection_exists.text}"
+            )
+
+        if collection_exists.json()["result"]["exists"]:
             collection_details = self.sync_client.get(
-                url=f"{self.qdrant_url}/collections/{self.collection_name}",
-                headers=self.headers
+                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+                headers=self.headers,
             )
             self.collection_info = collection_details.json()
-            print_verbose(f'Collection already exists.\nCollection details:{self.collection_info}')
+            print_verbose(
+                f"Collection already exists.\nCollection details:{self.collection_info}"
+            )
         else:
-            if quantization_config is None or quantization_config == 'binary':
+            if quantization_config is None or quantization_config == "binary":
                 quantization_params = {
                     "binary": {
                         "always_ram": False,
                     }
                 }
-            elif quantization_config == 'scalar':
+            elif quantization_config == "scalar":
                 quantization_params = {
-                    "scalar": {
-                        "type": "int8",
-                        "quantile": 0.99,
-                        "always_ram": False
-                    }
+                    "scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}
                 }
-            elif quantization_config == 'product':
+            elif quantization_config == "product":
                 quantization_params = {
-                    "product": {
-                        "compression": "x16",
-                        "always_ram": False
-                    }
+                    "product": {"compression": "x16", "always_ram": False}
                 }
             else:
-                raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")
+                raise Exception(
+                    "Quantization config must be one of 'scalar', 'binary' or 'product'"
+                )

             new_collection_status = self.sync_client.put(
-                url=f"{self.qdrant_url}/collections/{self.collection_name}",
+                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                 json={
-                    "vectors": {
-                        "size": 1536,
-                        "distance": "Cosine"
-                    },
-                    "quantization_config": quantization_params
+                    "vectors": {"size": 1536, "distance": "Cosine"},
+                    "quantization_config": quantization_params,
                 },
-                headers=self.headers
+                headers=self.headers,
             )
             if new_collection_status.json()["result"]:
                 collection_details = self.sync_client.get(
-                    url=f"{self.qdrant_url}/collections/{self.collection_name}",
-                    headers=self.headers
+                    url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+                    headers=self.headers,
                 )
                 self.collection_info = collection_details.json()
-                print_verbose(f'New collection created.\nCollection details:{self.collection_info}')
+                print_verbose(
+                    f"New collection created.\nCollection details:{self.collection_info}"
+                )
             else:
                 raise Exception("Error while creating new collection")
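The three accepted `quantization_config` values map directly onto Qdrant's collection-creation payload; a condensed sketch of the mapping applied in the PUT body above (vector size 1536 matches the default `text-embedding-ada-002` embeddings):

```python
def build_quantization_params(quantization_config=None):
    # None defaults to binary quantization, mirroring the branch above
    if quantization_config is None or quantization_config == "binary":
        return {"binary": {"always_ram": False}}
    if quantization_config == "scalar":
        return {"scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}}
    if quantization_config == "product":
        return {"product": {"compression": "x16", "always_ram": False}}
    raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")


# body of the PUT /collections/{collection_name} request shown above
collection_body = {
    "vectors": {"size": 1536, "distance": "Cosine"},
    "quantization_config": build_quantization_params("binary"),
}
```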
@@ -1394,14 +1386,14 @@ class QdrantSemanticCache(BaseCache):
                     "payload": {
                         "text": prompt,
                         "response": value,
-                    }
+                    },
                 },
             ]
         }
         keys = self.sync_client.put(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
             headers=self.headers,
-            json=data
+            json=data,
         )
         return
@@ -1434,13 +1426,13 @@ class QdrantSemanticCache(BaseCache):
                 }
             },
             "limit": 1,
-            "with_payload": True
+            "with_payload": True,
         }

         search_response = self.sync_client.post(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
             headers=self.headers,
-            json=data
+            json=data,
         )
         results = search_response.json()["result"]
@@ -1470,8 +1462,10 @@ class QdrantSemanticCache(BaseCache):
         pass

     async def async_set_cache(self, key, value, **kwargs):
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
         import uuid

+        from litellm.proxy.proxy_server import llm_model_list, llm_router
+
         print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")

         # get the prompt
@@ -1519,21 +1513,21 @@ class QdrantSemanticCache(BaseCache):
                     "payload": {
                         "text": prompt,
                         "response": value,
-                    }
+                    },
                 },
             ]
         }

         keys = await self.async_client.put(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
             headers=self.headers,
-            json=data
+            json=data,
         )
         return

     async def async_get_cache(self, key, **kwargs):
         print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         # get the messages
         messages = kwargs["messages"]
@@ -1579,13 +1573,13 @@ class QdrantSemanticCache(BaseCache):
                 }
             },
             "limit": 1,
-            "with_payload": True
+            "with_payload": True,
         }

         search_response = await self.async_client.post(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
             headers=self.headers,
-            json=data
+            json=data,
         )

         results = search_response.json()["result"]
@@ -1624,6 +1618,7 @@ class QdrantSemanticCache(BaseCache):
     async def _collection_info(self):
         return self.collection_info

+
 class S3Cache(BaseCache):
     def __init__(
         self,
@@ -2129,12 +2124,11 @@ class Cache:
         redis_semantic_cache_embedding_model="text-embedding-ada-002",
         redis_flush_size=None,
         disk_cache_dir=None,
-        qdrant_url: Optional[str] = None,
+        qdrant_api_base: Optional[str] = None,
         qdrant_api_key: Optional[str] = None,
         qdrant_collection_name: Optional[str] = None,
         qdrant_quantization_config: Optional[str] = None,
         qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
-        qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
         **kwargs,
     ):
         """
@@ -2145,9 +2139,8 @@ class Cache:
            host (str, optional): The host address for the Redis cache. Required if type is "redis".
            port (int, optional): The port number for the Redis cache. Required if type is "redis".
            password (str, optional): The password for the Redis cache. Required if type is "redis".
-           qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
-           qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. Required if qdrant_host_type is "cloud" and optional if qdrant_host_type is "local".
-           qdrant_host_type (str, optional): Can be either "local" or "cloud". Should be "local" when you are running a local qdrant cluster or "cloud" when you are using a qdrant cloud cluster.
+           qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
+           qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
            qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
            similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
@@ -2176,13 +2169,12 @@ class Cache:
             )
         elif type == "qdrant-semantic":
             self.cache = QdrantSemanticCache(
-                qdrant_url= qdrant_url,
+                qdrant_api_base=qdrant_api_base,
                 qdrant_api_key=qdrant_api_key,
                 collection_name=qdrant_collection_name,
                 similarity_threshold=similarity_threshold,
                 quantization_config=qdrant_quantization_config,
                 embedding_model=qdrant_semantic_cache_embedding_model,
-                host_type=qdrant_host_type
             )
         elif type == "local":
             self.cache = InMemoryCache()
@@ -210,7 +210,7 @@ class Logging:
         self.optional_params = optional_params
         self.model = model
         self.user = user
-        self.litellm_params = litellm_params
+        self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
         self.logger_fn = litellm_params.get("logger_fn", None)
         verbose_logger.debug(f"self.optional_params: {self.optional_params}")
@@ -2353,3 +2353,28 @@ def get_standard_logging_object_payload(
             "Error creating standard logging object - {}".format(str(e))
         )
         return None
+
+
+def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
+    if litellm_params is None:
+        litellm_params = {}
+
+    metadata = litellm_params.get("metadata", {}) or {}
+
+    ## check user_api_key_metadata for sensitive logging keys
+    cleaned_user_api_key_metadata = {}
+    if "user_api_key_metadata" in metadata and isinstance(
+        metadata["user_api_key_metadata"], dict
+    ):
+        for k, v in metadata["user_api_key_metadata"].items():
+            if k == "logging":  # prevent logging user logging keys
+                cleaned_user_api_key_metadata[k] = (
+                    "scrubbed_by_litellm_for_sensitive_keys"
+                )
+            else:
+                cleaned_user_api_key_metadata[k] = v
+
+    metadata["user_api_key_metadata"] = cleaned_user_api_key_metadata
+    litellm_params["metadata"] = metadata
+
+    return litellm_params
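A quick illustration of what the new helper does to key metadata before it reaches loggers, assuming `scrub_sensitive_keys_in_metadata` from the hunk above is in scope; the input dict is made up, and only the `logging` key inside `user_api_key_metadata` is replaced:

```python
litellm_params = {
    "metadata": {
        "user_api_key_metadata": {
            "team": "alpha",  # preserved as-is
            "logging": [{"callback_name": "langfuse"}],  # user logging config
        }
    }
}

scrubbed = scrub_sensitive_keys_in_metadata(litellm_params)
assert (
    scrubbed["metadata"]["user_api_key_metadata"]["logging"]
    == "scrubbed_by_litellm_for_sensitive_keys"
)
assert scrubbed["metadata"]["user_api_key_metadata"]["team"] == "alpha"
```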
@@ -188,9 +188,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
             elif value["type"] == "text":  # type: ignore
                 optional_params["response_mime_type"] = "text/plain"
             if "response_schema" in value:  # type: ignore
+                optional_params["response_mime_type"] = "application/json"
                 optional_params["response_schema"] = value["response_schema"]  # type: ignore
             elif value["type"] == "json_schema":  # type: ignore
                 if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
+                    optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
         if param == "tools" and isinstance(value, list):
             gtool_func_declarations = []
@@ -400,9 +402,11 @@ class VertexGeminiConfig:
             elif value["type"] == "text":
                 optional_params["response_mime_type"] = "text/plain"
             if "response_schema" in value:
+                optional_params["response_mime_type"] = "application/json"
                 optional_params["response_schema"] = value["response_schema"]
             elif value["type"] == "json_schema":  # type: ignore
                 if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
+                    optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
             if param == "frequency_penalty":
                 optional_params["frequency_penalty"] = value
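For reference, both hunks make the Gemini configs set `response_mime_type` whenever a `response_schema` is forwarded, since Gemini only honors a response schema when the mime type is `application/json`. An illustrative `optional_params` after mapping a JSON-schema response format (the schema content here is invented):

```python
optional_params = {
    "response_mime_type": "application/json",  # newly set alongside the schema
    "response_schema": {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
    },
}
```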
@@ -21,6 +21,13 @@ else:
     Span = Any


+class LiteLLMTeamRoles(enum.Enum):
+    # team admin
+    TEAM_ADMIN = "admin"
+    # team member
+    TEAM_MEMBER = "user"
+
+
 class LitellmUserRoles(str, enum.Enum):
     """
     Admin Roles:
@@ -335,6 +342,11 @@ class LiteLLMRoutes(enum.Enum):
         + sso_only_routes
     )

+    self_managed_routes: List = [
+        "/team/member_add",
+        "/team/member_delete",
+    ]  # routes that manage their own allowed/disallowed logic
+

 # class LiteLLMAllowedRoutes(LiteLLMBase):
 #     """
@@ -1308,6 +1320,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
     soft_budget: Optional[float] = None
     team_model_aliases: Optional[Dict] = None
     team_member_spend: Optional[float] = None
+    team_member: Optional[Member] = None
     team_metadata: Optional[Dict] = None

     # End User Params
@@ -975,8 +975,6 @@ async def user_api_key_auth(
             if not _is_user_proxy_admin(user_obj=user_obj):  # if non-admin
                 if is_llm_api_route(route=route):
                     pass
-                elif is_llm_api_route(route=request["route"].name):
-                    pass
                 elif (
                     route in LiteLLMRoutes.info_routes.value
                 ):  # check if user allowed to call an info route
@@ -1046,11 +1044,16 @@ async def user_api_key_auth(
                         status_code=status.HTTP_403_FORBIDDEN,
                         detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                     )

                 elif (
                     _user_role == LitellmUserRoles.INTERNAL_USER.value
                     and route in LiteLLMRoutes.internal_user_routes.value
                 ):
                     pass
+                elif (
+                    route in LiteLLMRoutes.self_managed_routes.value
+                ):  # routes that manage their own allowed/disallowed logic
+                    pass
         else:
             user_role = "unknown"
             user_id = "unknown"
@@ -285,14 +285,18 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
     return headers


-def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
+def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
     _metadata = request_data.get("metadata", None) or {}
+    headers = {}
     if "applied_guardrails" in _metadata:
-        return {
-            "x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
-        }
+        headers["x-litellm-applied-guardrails"] = ",".join(
+            _metadata["applied_guardrails"]
+        )

-    return None
+    if "semantic-similarity" in _metadata:
+        headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"])
+
+    return headers


 def add_guardrail_to_applied_guardrails_header(
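Sketch of the renamed helper's behavior on a request that both went through a guardrail and got a semantic-cache similarity score recorded (the metadata values are invented; the function itself is defined in the hunk above):

```python
request_data = {
    "metadata": {
        "applied_guardrails": ["lakera-pre-guard"],
        "semantic-similarity": 0.91,
    }
}

headers = get_logging_caching_headers(request_data)
# headers == {
#     "x-litellm-applied-guardrails": "lakera-pre-guard",
#     "x-litellm-semantic-similarity": "0.91",
# }
```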
@@ -95,7 +95,9 @@ def convert_key_logging_metadata_to_callback(
     for var, value in data.callback_vars.items():
         if team_callback_settings_obj.callback_vars is None:
             team_callback_settings_obj.callback_vars = {}
-        team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
+        team_callback_settings_obj.callback_vars[var] = (
+            litellm.utils.get_secret(value, default_value=value) or value
+        )

     return team_callback_settings_obj
@@ -130,7 +132,6 @@ def _get_dynamic_logging_metadata(
                 data=AddTeamCallback(**item),
                 team_callback_settings_obj=callback_settings_obj,
             )

     return callback_settings_obj

-
@@ -119,6 +119,7 @@ async def new_user(
         http_request=Request(
             scope={"type": "http", "path": "/user/new"},
         ),
+        user_api_key_dict=user_api_key_dict,
     )

     if data.send_invite_email is True:
@@ -849,7 +849,7 @@ async def generate_key_helper_fn(
     }

     if (
-        litellm.get_secret("DISABLE_KEY_NAME", False) == True
+        litellm.get_secret("DISABLE_KEY_NAME", False) is True
     ):  # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
         pass
     else:
@@ -30,7 +30,7 @@ from litellm.proxy._types import (
     UpdateTeamRequest,
     UserAPIKeyAuth,
 )
-from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
 from litellm.proxy.management_helpers.utils import (
     add_new_member,
     management_endpoint_wrapper,
@@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
 router = APIRouter()


+def _is_user_team_admin(
+    user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
+) -> bool:
+    for member in team_obj.members_with_roles:
+        if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
+            return True
+
+    return False
+
+
 #### TEAM MANAGEMENT ####
 @router.post(
     "/team/new",
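Note that, as merged, `_is_user_team_admin` matches on team membership by `user_id` and does not inspect the member's role; the endpoints below combine it with a proxy-admin check. A self-contained sketch with minimal stand-ins for the real pydantic models:

```python
from types import SimpleNamespace

# minimal stand-ins for UserAPIKeyAuth and LiteLLM_TeamTable
caller = SimpleNamespace(user_id="user-123")
team = SimpleNamespace(
    members_with_roles=[
        SimpleNamespace(user_id="user-123", role="admin"),
        SimpleNamespace(user_id="user-456", role="user"),
    ]
)


def _is_user_team_admin(user_api_key_dict, team_obj) -> bool:
    # same membership walk as the helper above
    for member in team_obj.members_with_roles:
        if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
            return True
    return False


assert _is_user_team_admin(caller, team) is True
```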
@@ -417,6 +427,7 @@ async def team_member_add(

    If user doesn't exist, new user row will also be added to User Table

+    Only proxy_admin or admin of team, allowed to access this endpoint.
    ```

    curl -X POST 'http://0.0.0.0:4000/team/member_add' \
@@ -465,6 +476,24 @@ async def team_member_add(

     complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())

+    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
+
+    if (
+        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
+        and not _is_user_team_admin(
+            user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
+        )
+    ):
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
+                    "/team/member_add",
+                    complete_team_data.team_id,
+                )
+            },
+        )
+
     if isinstance(data.member, Member):
         # add to team db
         new_member = data.member
@@ -569,6 +598,23 @@ async def team_member_delete(
         )
     existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())

+    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
+
+    if (
+        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
+        and not _is_user_team_admin(
+            user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
+        )
+    ):
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
+                    "/team/member_delete", existing_team_row.team_id
+                )
+            },
+        )
+
     ## DELETE MEMBER FROM TEAM
     new_team_members: List[Member] = []
     for m in existing_team_row.members_with_roles:
@@ -4,15 +4,17 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: openai-embedding
+    litellm_params:
+      model: openai/text-embedding-3-small
+      api_key: os.environ/OPENAI_API_KEY

-guardrails:
-  - guardrail_name: "lakera-pre-guard"
-    litellm_params:
-      guardrail: lakera  # supported values: "aporia", "bedrock", "lakera"
-      mode: "during_call"
-      api_key: os.environ/LAKERA_API_KEY
-      api_base: os.environ/LAKERA_API_BASE
-      category_thresholds:
-        prompt_injection: 0.1
-        jailbreak: 0.1
+litellm_settings:
+  set_verbose: True
+  cache: True          # set cache responses to True, litellm defaults to using a redis cache
+  cache_params:
+    type: qdrant-semantic
+    qdrant_semantic_cache_embedding_model: openai-embedding
+    qdrant_collection_name: test_collection
+    qdrant_quantization_config: binary
+    similarity_threshold: 0.8   # similarity threshold for semantic cache
@@ -149,7 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     show_missing_vars_in_env,
 )
 from litellm.proxy.common_utils.callback_utils import (
-    get_applied_guardrails_header,
+    get_logging_caching_headers,
     get_remaining_tokens_and_requests_from_request_data,
     initialize_callbacks_on_proxy,
 )
@@ -543,9 +543,9 @@ def get_custom_headers(
     )
     headers.update(remaining_tokens_header)

-    applied_guardrails = get_applied_guardrails_header(request_data)
-    if applied_guardrails:
-        headers.update(applied_guardrails)
+    logging_caching_headers = get_logging_caching_headers(request_data)
+    if logging_caching_headers:
+        headers.update(logging_caching_headers)

     try:
         return {
@@ -44,6 +44,7 @@ from litellm.proxy._types import (
     DynamoDBArgs,
     LiteLLM_VerificationTokenView,
     LitellmUserRoles,
+    Member,
     ResetTeamBudgetRequest,
     SpendLogsMetadata,
     SpendLogsPayload,
@@ -1395,6 +1396,7 @@ class PrismaClient:
                     t.blocked AS team_blocked,
                     t.team_alias AS team_alias,
                     t.metadata AS team_metadata,
+                    t.members_with_roles AS team_members_with_roles,
                     tm.spend AS team_member_spend,
                     m.aliases as team_model_aliases
                 FROM "LiteLLM_VerificationToken" AS v
@@ -1412,6 +1414,33 @@ class PrismaClient:
                     response["team_models"] = []
                 if response["team_blocked"] is None:
                     response["team_blocked"] = False
+
+                team_member: Optional[Member] = None
+                if (
+                    response["team_members_with_roles"] is not None
+                    and response["user_id"] is not None
+                ):
+                    ## find the team member corresponding to user id
+                    """
+                    [
+                        {
+                            "role": "admin",
+                            "user_id": "default_user_id",
+                            "user_email": null
+                        },
+                        {
+                            "role": "user",
+                            "user_id": null,
+                            "user_email": "test@email.com"
+                        }
+                    ]
+                    """
+                    for tm in response["team_members_with_roles"]:
+                        if tm.get("user_id") is not None and response[
+                            "user_id"
+                        ] == tm.get("user_id"):
+                            team_member = Member(**tm)
+                response["team_member"] = team_member
                 response = LiteLLM_VerificationTokenView(
                     **response, last_refreshed_at=time.time()
                 )
@@ -1558,6 +1558,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
                 "response_schema"
                 in mock_call.call_args.kwargs["json"]["generationConfig"]
             )
+            assert (
+                "response_mime_type"
+                in mock_call.call_args.kwargs["json"]["generationConfig"]
+            )
+            assert (
+                mock_call.call_args.kwargs["json"]["generationConfig"][
+                    "response_mime_type"
+                ]
+                == "application/json"
+            )
         else:
             assert (
                 "response_schema"
@@ -1733,8 +1733,10 @@ def test_caching_redis_simple(caplog, capsys):
     assert redis_service_logging_error is False
     assert "async success_callback: reaches cache for logging" not in captured.out

+
 @pytest.mark.asyncio
 async def test_qdrant_semantic_cache_acompletion():
+    litellm.set_verbose = True
     random_number = random.randint(
         1, 100000
     )  # add a random number to ensure it's always adding /reading from cache
@@ -1743,12 +1745,12 @@ async def test_qdrant_semantic_cache_acompletion():

     litellm.cache = Cache(
         type="qdrant-semantic",
-        qdrant_host_type="cloud",
-        qdrant_url=os.getenv("QDRANT_URL"),
+        _host_type="cloud",
+        qdrant_api_base=os.getenv("QDRANT_URL"),
         qdrant_api_key=os.getenv("QDRANT_API_KEY"),
-        qdrant_collection_name='test_collection',
+        qdrant_collection_name="test_collection",
         similarity_threshold=0.8,
-        qdrant_quantization_config="binary"
+        qdrant_quantization_config="binary",
     )

     response1 = await litellm.acompletion(
@@ -1759,6 +1761,7 @@ async def test_qdrant_semantic_cache_acompletion():
                 "content": f"write a one sentence poem about: {random_number}",
             }
         ],
+        mock_response="hello",
         max_tokens=20,
     )
     print(f"Response1: {response1}")
@@ -1778,6 +1781,7 @@ async def test_qdrant_semantic_cache_acompletion():
     print(f"Response2: {response2}")
     assert response1.id == response2.id

+
 @pytest.mark.asyncio
 async def test_qdrant_semantic_cache_acompletion_stream():
     try:
@@ -1790,12 +1794,11 @@ async def test_qdrant_semantic_cache_acompletion_stream():
         ]
         litellm.cache = Cache(
             type="qdrant-semantic",
-            qdrant_host_type="cloud",
-            qdrant_url=os.getenv("QDRANT_URL"),
+            qdrant_api_base=os.getenv("QDRANT_URL"),
             qdrant_api_key=os.getenv("QDRANT_API_KEY"),
-            qdrant_collection_name='test_collection',
+            qdrant_collection_name="test_collection",
             similarity_threshold=0.8,
-            qdrant_quantization_config="binary"
+            qdrant_quantization_config="binary",
         )
         print("Test Qdrant Semantic Caching with streaming + acompletion")
         response_1_content = ""
@@ -1807,6 +1810,7 @@ async def test_qdrant_semantic_cache_acompletion_stream():
             max_tokens=40,
             temperature=1,
             stream=True,
+            mock_response="hi",
         )
         async for chunk in response1:
             response_1_id = chunk.id
@@ -1830,7 +1834,9 @@ async def test_qdrant_semantic_cache_acompletion_stream():
         assert (
             response_1_content == response_2_content
         ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
-        assert (response_1_id == response_2_id), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}"
+        assert (
+            response_1_id == response_2_id
+        ), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}"
         litellm.cache = None
         litellm.success_callback = []
         litellm._async_success_callback = []
@@ -3283,9 +3283,9 @@ def test_completion_together_ai_mixtral():
 # test_completion_together_ai_mixtral()


-def test_completion_together_ai_yi_chat():
+def test_completion_together_ai_llama():
     litellm.set_verbose = True
-    model_name = "together_ai/mistralai/Mistral-7B-Instruct-v0.1"
+    model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
     try:
         messages = [
             {"role": "user", "content": "What llm are you?"},
@@ -909,7 +909,7 @@ async def test_create_team_member_add(prisma_client, new_member_method):

     await team_member_add(
         data=team_member_add_request,
-        user_api_key_dict=UserAPIKeyAuth(),
+        user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
         http_request=Request(
             scope={"type": "http", "path": "/user/new"},
         ),
@@ -930,6 +930,172 @@ async def test_create_team_member_add(prisma_client, new_member_method):
     )


+@pytest.mark.parametrize("team_member_role", ["admin", "user"])
+@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
+@pytest.mark.asyncio
+async def test_create_team_member_add_team_admin_user_api_key_auth(
+    prisma_client, team_member_role, team_route
+):
+    import time
+
+    from fastapi import Request
+
+    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
+    from litellm.proxy.proxy_server import (
+        ProxyException,
+        hash_token,
+        user_api_key_auth,
+        user_api_key_cache,
+    )
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm, "max_internal_user_budget", 10)
+    setattr(litellm, "internal_user_budget_duration", "5m")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    user = f"ishaan {uuid.uuid4().hex}"
+    _team_id = "litellm-test-client-id-new"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        token=hash_token(user_key),
+        team_member=Member(role=team_member_role, user_id=user),
+        last_refreshed_at=time.time(),
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+
+    team_obj = LiteLLM_TeamTableCachedObj(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+
+    ## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
+    import json
+
+    from starlette.datastructures import URL
+
+    request = Request(scope={"type": "http"})
+    request._url = URL(url=team_route)
+
+    body = {}
+    json_bytes = json.dumps(body).encode("utf-8")
+
+    request._body = json_bytes
+
+    ## ALLOWED BY USER_API_KEY_AUTH
+    await user_api_key_auth(request=request, api_key="Bearer " + user_key)
+
+
+@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
+@pytest.mark.parametrize("user_role", ["admin", "user"])
+@pytest.mark.asyncio
+async def test_create_team_member_add_team_admin(
+    prisma_client, new_member_method, user_role
+):
+    """
+    Relevant issue - https://github.com/BerriAI/litellm/issues/5300
+
+    Allow team admins to:
+    - Add and remove team members
+    - raise error if team member not an existing 'internal_user'
+    """
+    import time
+
+    from fastapi import Request
+
+    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
+    from litellm.proxy.proxy_server import (
+        HTTPException,
+        ProxyException,
+        hash_token,
+        user_api_key_auth,
+        user_api_key_cache,
+    )
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm, "max_internal_user_budget", 10)
+    setattr(litellm, "internal_user_budget_duration", "5m")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    user = f"ishaan {uuid.uuid4().hex}"
+    _team_id = "litellm-test-client-id-new"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        user_id=user,
+        token=hash_token(user_key),
+        last_refreshed_at=time.time(),
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+
+    team_obj = LiteLLM_TeamTableCachedObj(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        members_with_roles=[Member(role=user_role, user_id=user)],
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+    if new_member_method == "user_id":
+        data = {
+            "team_id": _team_id,
+            "member": [{"role": "user", "user_id": user}],
+        }
+    elif new_member_method == "user_email":
+        data = {
+            "team_id": _team_id,
+            "member": [{"role": "user", "user_email": user}],
+        }
+    team_member_add_request = TeamMemberAddRequest(**data)
+
+    with patch(
+        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
+        new_callable=AsyncMock,
+    ) as mock_litellm_usertable:
+        mock_client = AsyncMock()
+        mock_litellm_usertable.upsert = mock_client
+        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
+
+        try:
+            await team_member_add(
+                data=team_member_add_request,
+                user_api_key_dict=valid_token,
+                http_request=Request(
+                    scope={"type": "http", "path": "/user/new"},
+                ),
+            )
+        except HTTPException as e:
+            if user_role == "user":
+                assert e.status_code == 403
+            else:
+                raise e
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args: {mock_client.call_args}")
+        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
+
+        assert (
+            mock_client.call_args.kwargs["data"]["create"]["max_budget"]
+            == litellm.max_internal_user_budget
+        )
+        assert (
+            mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
+            == litellm.internal_user_budget_duration
+        )
+
+
 @pytest.mark.asyncio
 async def test_user_info_team_list(prisma_client):
     """Assert user_info for admin calls team_list function"""
@@ -1116,8 +1282,8 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
         "callback_name": "langfuse",
         "callback_type": "success",
         "callback_vars": {
-            "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
-            "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+            "langfuse_public_key": "my-mock-public-key",
+            "langfuse_secret_key": "my-mock-secret-key",
             "langfuse_host": "https://us.cloud.langfuse.com",
         },
     }
@@ -1165,7 +1331,9 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
     assert "success_callback" in new_data
     assert new_data["success_callback"] == ["langfuse"]
     assert "langfuse_public_key" in new_data
+    assert new_data["langfuse_public_key"] == "my-mock-public-key"
     assert "langfuse_secret_key" in new_data
+    assert new_data["langfuse_secret_key"] == "my-mock-secret-key"


 @pytest.mark.asyncio
@@ -121,7 +121,7 @@ import importlib.metadata
 from openai import OpenAIError as OriginalError

 from ._logging import verbose_logger
-from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
+from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
 from .exceptions import (
     APIConnectionError,
     APIError,
@@ -8622,7 +8622,9 @@ def get_secret(
                     return secret_value_as_bool
                 else:
                     return secret
-            except:
+            except Exception:
+                if default_value is not None:
+                    return default_value
                 return secret
     except Exception as e:
         if default_value is not None:
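With the added fallback, a `get_secret` lookup that fails inside the boolean-coercion path now returns the caller's `default_value` instead of propagating the raw secret; for example (assumes litellm is installed and `OPTIONAL_FLAG` may be unset):

```python
import litellm

# "disabled" comes back if OPTIONAL_FLAG is missing or the lookup errors
value = litellm.get_secret("OPTIONAL_FLAG", default_value="disabled")
```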