mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Merge branch 'main' into litellm_disable_storing_master_key_hash_in_db
This commit is contained in:
commit
3971880af4
22 changed files with 537 additions and 169 deletions
|
@ -161,8 +161,7 @@ random_number = random.randint(
|
|||
print("testing semantic caching")
|
||||
litellm.cache = Cache(
|
||||
type="qdrant-semantic",
|
||||
qdrant_host_type="cloud", # can be either 'cloud' or 'local'
|
||||
qdrant_url=os.environ["QDRANT_URL"],
|
||||
qdrant_api_base=os.environ["QDRANT_API_BASE"],
|
||||
qdrant_api_key=os.environ["QDRANT_API_KEY"],
|
||||
qdrant_collection_name="your_collection_name", # any name of your collection
|
||||
similarity_threshold=0.7, # similarity threshold for cache hits, 0 == no similarity, 1 = exact matches, 0.5 == 50% similarity
|
||||
|
@ -491,12 +490,11 @@ def __init__(
|
|||
disk_cache_dir=None,
|
||||
|
||||
# qdrant cache params
|
||||
qdrant_url: Optional[str] = None,
|
||||
qdrant_api_base: Optional[str] = None,
|
||||
qdrant_api_key: Optional[str] = None,
|
||||
qdrant_collection_name: Optional[str] = None,
|
||||
qdrant_quantization_config: Optional[str] = None,
|
||||
qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
|
||||
|
||||
**kwargs
|
||||
):
|
||||
|
|
|
@ -7,6 +7,7 @@ Cache LLM Responses
|
|||
LiteLLM supports:
|
||||
- In Memory Cache
|
||||
- Redis Cache
|
||||
- Qdrant Semantic Cache
|
||||
- Redis Semantic Cache
|
||||
- s3 Bucket Cache
|
||||
|
||||
|
@ -103,6 +104,66 @@ $ litellm --config /path/to/config.yaml
|
|||
```
|
||||
</TabItem>
|
||||
|
||||
|
||||
<TabItem value="qdrant-semantic" label="Qdrant Semantic cache">
|
||||
|
||||
Caching can be enabled by adding the `cache` key in the `config.yaml`
|
||||
|
||||
#### Step 1: Add `cache` to the config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: fake-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
- model_name: openai-embedding
|
||||
litellm_params:
|
||||
model: openai/text-embedding-3-small
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
cache: True # set cache responses to True, litellm defaults to using a redis cache
|
||||
cache_params:
|
||||
type: qdrant-semantic
|
||||
qdrant_semantic_cache_embedding_model: openai-embedding # the model should be defined on the model_list
|
||||
qdrant_collection_name: test_collection
|
||||
qdrant_quantization_config: binary
|
||||
similarity_threshold: 0.8 # similarity threshold for semantic cache
|
||||
```
|
||||
|
||||
#### Step 2: Add Qdrant Credentials to your .env
|
||||
|
||||
```shell
|
||||
QDRANT_API_KEY = "16rJUMBRx*************"
|
||||
QDRANT_API_BASE = "https://5392d382-45*********.cloud.qdrant.io"
|
||||
```
|
||||
|
||||
#### Step 3: Run proxy with config
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
|
||||
#### Step 4. Test it
|
||||
|
||||
```shell
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{
|
||||
"model": "fake-openai-endpoint",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
**Expect to see `x-litellm-semantic-similarity` in the response headers when semantic caching is one**
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="s3" label="s3 cache">
|
||||
|
||||
#### Step 1: Add `cache` to the config.yaml
|
||||
|
@ -182,6 +243,9 @@ REDIS_<redis-kwarg-name> = ""
|
|||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
|
||||
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
|
|
@ -74,6 +74,7 @@ const sidebars = {
|
|||
"proxy/alerting",
|
||||
"proxy/ui",
|
||||
"proxy/prometheus",
|
||||
"proxy/caching",
|
||||
"proxy/pass_through",
|
||||
"proxy/email",
|
||||
"proxy/multiple_admins",
|
||||
|
@ -88,7 +89,6 @@ const sidebars = {
|
|||
"proxy/health",
|
||||
"proxy/debugging",
|
||||
"proxy/pii_masking",
|
||||
"proxy/caching",
|
||||
"proxy/call_hooks",
|
||||
"proxy/rules",
|
||||
"proxy/cli",
|
||||
|
|
|
@ -1219,25 +1219,28 @@ class RedisSemanticCache(BaseCache):
|
|||
async def _index_info(self):
|
||||
return await self.index.ainfo()
|
||||
|
||||
|
||||
class QdrantSemanticCache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
qdrant_url=None,
|
||||
qdrant_api_key = None,
|
||||
collection_name=None,
|
||||
similarity_threshold=None,
|
||||
quantization_config=None,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
host_type = None
|
||||
):
|
||||
self,
|
||||
qdrant_api_base=None,
|
||||
qdrant_api_key=None,
|
||||
collection_name=None,
|
||||
similarity_threshold=None,
|
||||
quantization_config=None,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
host_type=None,
|
||||
):
|
||||
import os
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
_get_async_httpx_client
|
||||
_get_async_httpx_client,
|
||||
_get_httpx_client,
|
||||
)
|
||||
|
||||
if collection_name is None:
|
||||
raise Exception("collection_name must be provided, passed None")
|
||||
|
||||
|
||||
self.collection_name = collection_name
|
||||
print_verbose(
|
||||
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
|
||||
|
@ -1247,108 +1250,97 @@ class QdrantSemanticCache(BaseCache):
|
|||
raise Exception("similarity_threshold must be provided, passed None")
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.embedding_model = embedding_model
|
||||
headers = {}
|
||||
|
||||
if host_type=="cloud":
|
||||
import os
|
||||
if qdrant_url is None:
|
||||
qdrant_url = os.getenv('QDRANT_URL')
|
||||
if qdrant_api_key is None:
|
||||
qdrant_api_key = os.getenv('QDRANT_API_KEY')
|
||||
if qdrant_url is not None and qdrant_api_key is not None:
|
||||
headers = {
|
||||
"api-key": qdrant_api_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
else:
|
||||
raise Exception("Qdrant url and api_key must be provided for qdrant cloud hosting")
|
||||
elif host_type=="local":
|
||||
import os
|
||||
if qdrant_url is None:
|
||||
qdrant_url = os.getenv('QDRANT_URL')
|
||||
if qdrant_url is None:
|
||||
raise Exception("Qdrant url must be provided for qdrant local hosting")
|
||||
if qdrant_api_key is None:
|
||||
qdrant_api_key = os.getenv('QDRANT_API_KEY')
|
||||
if qdrant_api_key is None:
|
||||
print_verbose('Running locally without API Key.')
|
||||
headers= {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
else:
|
||||
print_verbose("Running locally with API Key")
|
||||
headers = {
|
||||
"api-key": qdrant_api_key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
else:
|
||||
raise Exception("Host type can be either 'local' or 'cloud'")
|
||||
|
||||
self.qdrant_url = qdrant_url
|
||||
# check if defined as os.environ/ variable
|
||||
if qdrant_api_base:
|
||||
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_base = litellm.get_secret(qdrant_api_base)
|
||||
if qdrant_api_key:
|
||||
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_key = litellm.get_secret(qdrant_api_key)
|
||||
|
||||
qdrant_api_base = (
|
||||
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
|
||||
)
|
||||
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
|
||||
headers = {"api-key": qdrant_api_key, "Content-Type": "application/json"}
|
||||
|
||||
if qdrant_api_key is None or qdrant_api_base is None:
|
||||
raise ValueError("Qdrant url and api_key must be")
|
||||
|
||||
self.qdrant_api_base = qdrant_api_base
|
||||
self.qdrant_api_key = qdrant_api_key
|
||||
print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}")
|
||||
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
|
||||
|
||||
self.headers = headers
|
||||
|
||||
|
||||
self.sync_client = _get_httpx_client()
|
||||
self.async_client = _get_async_httpx_client()
|
||||
|
||||
if quantization_config is None:
|
||||
print('Quantization config is not provided. Default binary quantization will be used.')
|
||||
|
||||
collection_exists = self.sync_client.get(
|
||||
url= f"{self.qdrant_url}/collections/{self.collection_name}/exists",
|
||||
headers=self.headers
|
||||
print_verbose(
|
||||
"Quantization config is not provided. Default binary quantization will be used."
|
||||
)
|
||||
if collection_exists.json()['result']['exists']:
|
||||
collection_exists = self.sync_client.get(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
|
||||
headers=self.headers,
|
||||
)
|
||||
if collection_exists.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Error from qdrant checking if /collections exist {collection_exists.text}"
|
||||
)
|
||||
|
||||
if collection_exists.json()["result"]["exists"]:
|
||||
collection_details = self.sync_client.get(
|
||||
url=f"{self.qdrant_url}/collections/{self.collection_name}",
|
||||
headers=self.headers
|
||||
)
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
headers=self.headers,
|
||||
)
|
||||
self.collection_info = collection_details.json()
|
||||
print_verbose(f'Collection already exists.\nCollection details:{self.collection_info}')
|
||||
print_verbose(
|
||||
f"Collection already exists.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
else:
|
||||
if quantization_config is None or quantization_config == 'binary':
|
||||
if quantization_config is None or quantization_config == "binary":
|
||||
quantization_params = {
|
||||
"binary": {
|
||||
"always_ram": False,
|
||||
}
|
||||
}
|
||||
elif quantization_config == 'scalar':
|
||||
elif quantization_config == "scalar":
|
||||
quantization_params = {
|
||||
"scalar": {
|
||||
"type": "int8",
|
||||
"quantile": 0.99,
|
||||
"always_ram": False
|
||||
}
|
||||
"scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}
|
||||
}
|
||||
elif quantization_config == 'product':
|
||||
elif quantization_config == "product":
|
||||
quantization_params = {
|
||||
"product": {
|
||||
"compression": "x16",
|
||||
"always_ram": False
|
||||
}
|
||||
"product": {"compression": "x16", "always_ram": False}
|
||||
}
|
||||
else:
|
||||
raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
"Quantization config must be one of 'scalar', 'binary' or 'product'"
|
||||
)
|
||||
|
||||
new_collection_status = self.sync_client.put(
|
||||
url=f"{self.qdrant_url}/collections/{self.collection_name}",
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
json={
|
||||
"vectors": {
|
||||
"size": 1536,
|
||||
"distance": "Cosine"
|
||||
},
|
||||
"quantization_config": quantization_params
|
||||
"vectors": {"size": 1536, "distance": "Cosine"},
|
||||
"quantization_config": quantization_params,
|
||||
},
|
||||
headers=self.headers
|
||||
headers=self.headers,
|
||||
)
|
||||
if new_collection_status.json()["result"]:
|
||||
collection_details = self.sync_client.get(
|
||||
url=f"{self.qdrant_url}/collections/{self.collection_name}",
|
||||
headers=self.headers
|
||||
)
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
headers=self.headers,
|
||||
)
|
||||
self.collection_info = collection_details.json()
|
||||
print_verbose(f'New collection created.\nCollection details:{self.collection_info}')
|
||||
print_verbose(
|
||||
f"New collection created.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
else:
|
||||
raise Exception("Error while creating new collection")
|
||||
|
||||
|
@ -1394,14 +1386,14 @@ class QdrantSemanticCache(BaseCache):
|
|||
"payload": {
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
keys = self.sync_client.put(
|
||||
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data
|
||||
json=data,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -1433,14 +1425,14 @@ class QdrantSemanticCache(BaseCache):
|
|||
"oversampling": 3.0,
|
||||
}
|
||||
},
|
||||
"limit":1,
|
||||
"with_payload": True
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
search_response = self.sync_client.post(
|
||||
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
headers=self.headers,
|
||||
json=data
|
||||
json=data,
|
||||
)
|
||||
results = search_response.json()["result"]
|
||||
|
||||
|
@ -1470,8 +1462,10 @@ class QdrantSemanticCache(BaseCache):
|
|||
pass
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
from litellm.proxy.proxy_server import llm_router, llm_model_list
|
||||
import uuid
|
||||
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
# get the prompt
|
||||
|
@ -1519,21 +1513,21 @@ class QdrantSemanticCache(BaseCache):
|
|||
"payload": {
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
}
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
keys = await self.async_client.put(
|
||||
url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data
|
||||
json=data,
|
||||
)
|
||||
return
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||
from litellm.proxy.proxy_server import llm_router, llm_model_list
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
|
@ -1578,14 +1572,14 @@ class QdrantSemanticCache(BaseCache):
|
|||
"oversampling": 3.0,
|
||||
}
|
||||
},
|
||||
"limit":1,
|
||||
"with_payload": True
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
search_response = await self.async_client.post(
|
||||
url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
headers=self.headers,
|
||||
json=data
|
||||
json=data,
|
||||
)
|
||||
|
||||
results = search_response.json()["result"]
|
||||
|
@ -1624,6 +1618,7 @@ class QdrantSemanticCache(BaseCache):
|
|||
async def _collection_info(self):
|
||||
return self.collection_info
|
||||
|
||||
|
||||
class S3Cache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -2129,12 +2124,11 @@ class Cache:
|
|||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
redis_flush_size=None,
|
||||
disk_cache_dir=None,
|
||||
qdrant_url: Optional[str] = None,
|
||||
qdrant_api_base: Optional[str] = None,
|
||||
qdrant_api_key: Optional[str] = None,
|
||||
qdrant_collection_name: Optional[str] = None,
|
||||
qdrant_quantization_config: Optional[str] = None,
|
||||
qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -2145,9 +2139,8 @@ class Cache:
|
|||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. Required if qdrant_host_type is "cloud" and optional if qdrant_host_type is "local".
|
||||
qdrant_host_type (str, optional): Can be either "local" or "cloud". Should be "local" when you are running a local qdrant cluster or "cloud" when you are using a qdrant cloud cluster.
|
||||
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||
|
||||
|
@ -2176,13 +2169,12 @@ class Cache:
|
|||
)
|
||||
elif type == "qdrant-semantic":
|
||||
self.cache = QdrantSemanticCache(
|
||||
qdrant_url= qdrant_url,
|
||||
qdrant_api_key= qdrant_api_key,
|
||||
collection_name= qdrant_collection_name,
|
||||
similarity_threshold= similarity_threshold,
|
||||
quantization_config= qdrant_quantization_config,
|
||||
embedding_model= qdrant_semantic_cache_embedding_model,
|
||||
host_type=qdrant_host_type
|
||||
qdrant_api_base=qdrant_api_base,
|
||||
qdrant_api_key=qdrant_api_key,
|
||||
collection_name=qdrant_collection_name,
|
||||
similarity_threshold=similarity_threshold,
|
||||
quantization_config=qdrant_quantization_config,
|
||||
embedding_model=qdrant_semantic_cache_embedding_model,
|
||||
)
|
||||
elif type == "local":
|
||||
self.cache = InMemoryCache()
|
||||
|
|
|
@ -210,7 +210,7 @@ class Logging:
|
|||
self.optional_params = optional_params
|
||||
self.model = model
|
||||
self.user = user
|
||||
self.litellm_params = litellm_params
|
||||
self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
|
||||
self.logger_fn = litellm_params.get("logger_fn", None)
|
||||
verbose_logger.debug(f"self.optional_params: {self.optional_params}")
|
||||
|
||||
|
@ -2353,3 +2353,28 @@ def get_standard_logging_object_payload(
|
|||
"Error creating standard logging object - {}".format(str(e))
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
|
||||
if litellm_params is None:
|
||||
litellm_params = {}
|
||||
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
|
||||
## check user_api_key_metadata for sensitive logging keys
|
||||
cleaned_user_api_key_metadata = {}
|
||||
if "user_api_key_metadata" in metadata and isinstance(
|
||||
metadata["user_api_key_metadata"], dict
|
||||
):
|
||||
for k, v in metadata["user_api_key_metadata"].items():
|
||||
if k == "logging": # prevent logging user logging keys
|
||||
cleaned_user_api_key_metadata[k] = (
|
||||
"scrubbed_by_litellm_for_sensitive_keys"
|
||||
)
|
||||
else:
|
||||
cleaned_user_api_key_metadata[k] = v
|
||||
|
||||
metadata["user_api_key_metadata"] = cleaned_user_api_key_metadata
|
||||
litellm_params["metadata"] = metadata
|
||||
|
||||
return litellm_params
|
||||
|
|
|
@ -188,9 +188,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
|
|||
elif value["type"] == "text": # type: ignore
|
||||
optional_params["response_mime_type"] = "text/plain"
|
||||
if "response_schema" in value: # type: ignore
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["response_schema"] # type: ignore
|
||||
elif value["type"] == "json_schema": # type: ignore
|
||||
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
|
||||
if param == "tools" and isinstance(value, list):
|
||||
gtool_func_declarations = []
|
||||
|
@ -400,9 +402,11 @@ class VertexGeminiConfig:
|
|||
elif value["type"] == "text":
|
||||
optional_params["response_mime_type"] = "text/plain"
|
||||
if "response_schema" in value:
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["response_schema"]
|
||||
elif value["type"] == "json_schema": # type: ignore
|
||||
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
|
|
|
@ -4,4 +4,4 @@ model_list:
|
|||
model: "*"
|
||||
|
||||
general_settings:
|
||||
disable_adding_master_key_hash_to_db: True
|
||||
disable_adding_master_key_hash_to_db: True
|
||||
|
|
|
@ -21,6 +21,13 @@ else:
|
|||
Span = Any
|
||||
|
||||
|
||||
class LiteLLMTeamRoles(enum.Enum):
|
||||
# team admin
|
||||
TEAM_ADMIN = "admin"
|
||||
# team member
|
||||
TEAM_MEMBER = "user"
|
||||
|
||||
|
||||
class LitellmUserRoles(str, enum.Enum):
|
||||
"""
|
||||
Admin Roles:
|
||||
|
@ -335,6 +342,11 @@ class LiteLLMRoutes(enum.Enum):
|
|||
+ sso_only_routes
|
||||
)
|
||||
|
||||
self_managed_routes: List = [
|
||||
"/team/member_add",
|
||||
"/team/member_delete",
|
||||
] # routes that manage their own allowed/disallowed logic
|
||||
|
||||
|
||||
# class LiteLLMAllowedRoutes(LiteLLMBase):
|
||||
# """
|
||||
|
@ -1308,6 +1320,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
|||
soft_budget: Optional[float] = None
|
||||
team_model_aliases: Optional[Dict] = None
|
||||
team_member_spend: Optional[float] = None
|
||||
team_member: Optional[Member] = None
|
||||
team_metadata: Optional[Dict] = None
|
||||
|
||||
# End User Params
|
||||
|
|
|
@ -975,8 +975,6 @@ async def user_api_key_auth(
|
|||
if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin
|
||||
if is_llm_api_route(route=route):
|
||||
pass
|
||||
elif is_llm_api_route(route=request["route"].name):
|
||||
pass
|
||||
elif (
|
||||
route in LiteLLMRoutes.info_routes.value
|
||||
): # check if user allowed to call an info route
|
||||
|
@ -1046,11 +1044,16 @@ async def user_api_key_auth(
|
|||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
|
||||
)
|
||||
|
||||
elif (
|
||||
_user_role == LitellmUserRoles.INTERNAL_USER.value
|
||||
and route in LiteLLMRoutes.internal_user_routes.value
|
||||
):
|
||||
pass
|
||||
elif (
|
||||
route in LiteLLMRoutes.self_managed_routes.value
|
||||
): # routes that manage their own allowed/disallowed logic
|
||||
pass
|
||||
else:
|
||||
user_role = "unknown"
|
||||
user_id = "unknown"
|
||||
|
|
|
@ -285,14 +285,18 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
|
|||
return headers
|
||||
|
||||
|
||||
def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
|
||||
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
|
||||
_metadata = request_data.get("metadata", None) or {}
|
||||
headers = {}
|
||||
if "applied_guardrails" in _metadata:
|
||||
return {
|
||||
"x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
|
||||
}
|
||||
headers["x-litellm-applied-guardrails"] = ",".join(
|
||||
_metadata["applied_guardrails"]
|
||||
)
|
||||
|
||||
return None
|
||||
if "semantic-similarity" in _metadata:
|
||||
headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"])
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
def add_guardrail_to_applied_guardrails_header(
|
||||
|
|
|
@ -95,7 +95,9 @@ def convert_key_logging_metadata_to_callback(
|
|||
for var, value in data.callback_vars.items():
|
||||
if team_callback_settings_obj.callback_vars is None:
|
||||
team_callback_settings_obj.callback_vars = {}
|
||||
team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
|
||||
team_callback_settings_obj.callback_vars[var] = (
|
||||
litellm.utils.get_secret(value, default_value=value) or value
|
||||
)
|
||||
|
||||
return team_callback_settings_obj
|
||||
|
||||
|
@ -130,7 +132,6 @@ def _get_dynamic_logging_metadata(
|
|||
data=AddTeamCallback(**item),
|
||||
team_callback_settings_obj=callback_settings_obj,
|
||||
)
|
||||
|
||||
return callback_settings_obj
|
||||
|
||||
|
||||
|
|
|
@ -119,6 +119,7 @@ async def new_user(
|
|||
http_request=Request(
|
||||
scope={"type": "http", "path": "/user/new"},
|
||||
),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
if data.send_invite_email is True:
|
||||
|
|
|
@ -849,7 +849,7 @@ async def generate_key_helper_fn(
|
|||
}
|
||||
|
||||
if (
|
||||
litellm.get_secret("DISABLE_KEY_NAME", False) == True
|
||||
litellm.get_secret("DISABLE_KEY_NAME", False) is True
|
||||
): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
|
||||
pass
|
||||
else:
|
||||
|
|
|
@ -30,7 +30,7 @@ from litellm.proxy._types import (
|
|||
UpdateTeamRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
management_endpoint_wrapper,
|
||||
|
@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
|
|||
router = APIRouter()
|
||||
|
||||
|
||||
def _is_user_team_admin(
|
||||
user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
|
||||
) -> bool:
|
||||
for member in team_obj.members_with_roles:
|
||||
if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
#### TEAM MANAGEMENT ####
|
||||
@router.post(
|
||||
"/team/new",
|
||||
|
@ -417,6 +427,7 @@ async def team_member_add(
|
|||
|
||||
If user doesn't exist, new user row will also be added to User Table
|
||||
|
||||
Only proxy_admin or admin of team, allowed to access this endpoint.
|
||||
```
|
||||
|
||||
curl -X POST 'http://0.0.0.0:4000/team/member_add' \
|
||||
|
@ -465,6 +476,24 @@ async def team_member_add(
|
|||
|
||||
complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())
|
||||
|
||||
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||
|
||||
if (
|
||||
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||
and not _is_user_team_admin(
|
||||
user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||
"/team/member_add",
|
||||
complete_team_data.team_id,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if isinstance(data.member, Member):
|
||||
# add to team db
|
||||
new_member = data.member
|
||||
|
@ -569,6 +598,23 @@ async def team_member_delete(
|
|||
)
|
||||
existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())
|
||||
|
||||
## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
|
||||
|
||||
if (
|
||||
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||
and not _is_user_team_admin(
|
||||
user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
|
||||
)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
|
||||
"/team/member_delete", existing_team_row.team_id
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
## DELETE MEMBER FROM TEAM
|
||||
new_team_members: List[Member] = []
|
||||
for m in existing_team_row.members_with_roles:
|
||||
|
|
|
@ -4,15 +4,17 @@ model_list:
|
|||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
guardrails:
|
||||
- guardrail_name: "lakera-pre-guard"
|
||||
- model_name: openai-embedding
|
||||
litellm_params:
|
||||
guardrail: lakera # supported values: "aporia", "bedrock", "lakera"
|
||||
mode: "during_call"
|
||||
api_key: os.environ/LAKERA_API_KEY
|
||||
api_base: os.environ/LAKERA_API_BASE
|
||||
category_thresholds:
|
||||
prompt_injection: 0.1
|
||||
jailbreak: 0.1
|
||||
|
||||
model: openai/text-embedding-3-small
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
cache: True # set cache responses to True, litellm defaults to using a redis cache
|
||||
cache_params:
|
||||
type: qdrant-semantic
|
||||
qdrant_semantic_cache_embedding_model: openai-embedding
|
||||
qdrant_collection_name: test_collection
|
||||
qdrant_quantization_config: binary
|
||||
similarity_threshold: 0.8 # similarity threshold for semantic cache
|
|
@ -149,7 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
|
|||
show_missing_vars_in_env,
|
||||
)
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
get_applied_guardrails_header,
|
||||
get_logging_caching_headers,
|
||||
get_remaining_tokens_and_requests_from_request_data,
|
||||
initialize_callbacks_on_proxy,
|
||||
)
|
||||
|
@ -543,9 +543,9 @@ def get_custom_headers(
|
|||
)
|
||||
headers.update(remaining_tokens_header)
|
||||
|
||||
applied_guardrails = get_applied_guardrails_header(request_data)
|
||||
if applied_guardrails:
|
||||
headers.update(applied_guardrails)
|
||||
logging_caching_headers = get_logging_caching_headers(request_data)
|
||||
if logging_caching_headers:
|
||||
headers.update(logging_caching_headers)
|
||||
|
||||
try:
|
||||
return {
|
||||
|
|
|
@ -44,6 +44,7 @@ from litellm.proxy._types import (
|
|||
DynamoDBArgs,
|
||||
LiteLLM_VerificationTokenView,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
ResetTeamBudgetRequest,
|
||||
SpendLogsMetadata,
|
||||
SpendLogsPayload,
|
||||
|
@ -1395,6 +1396,7 @@ class PrismaClient:
|
|||
t.blocked AS team_blocked,
|
||||
t.team_alias AS team_alias,
|
||||
t.metadata AS team_metadata,
|
||||
t.members_with_roles AS team_members_with_roles,
|
||||
tm.spend AS team_member_spend,
|
||||
m.aliases as team_model_aliases
|
||||
FROM "LiteLLM_VerificationToken" AS v
|
||||
|
@ -1412,6 +1414,33 @@ class PrismaClient:
|
|||
response["team_models"] = []
|
||||
if response["team_blocked"] is None:
|
||||
response["team_blocked"] = False
|
||||
|
||||
team_member: Optional[Member] = None
|
||||
if (
|
||||
response["team_members_with_roles"] is not None
|
||||
and response["user_id"] is not None
|
||||
):
|
||||
## find the team member corresponding to user id
|
||||
"""
|
||||
[
|
||||
{
|
||||
"role": "admin",
|
||||
"user_id": "default_user_id",
|
||||
"user_email": null
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"user_id": null,
|
||||
"user_email": "test@email.com"
|
||||
}
|
||||
]
|
||||
"""
|
||||
for tm in response["team_members_with_roles"]:
|
||||
if tm.get("user_id") is not None and response[
|
||||
"user_id"
|
||||
] == tm.get("user_id"):
|
||||
team_member = Member(**tm)
|
||||
response["team_member"] = team_member
|
||||
response = LiteLLM_VerificationTokenView(
|
||||
**response, last_refreshed_at=time.time()
|
||||
)
|
||||
|
|
|
@ -1558,6 +1558,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
|
|||
"response_schema"
|
||||
in mock_call.call_args.kwargs["json"]["generationConfig"]
|
||||
)
|
||||
assert (
|
||||
"response_mime_type"
|
||||
in mock_call.call_args.kwargs["json"]["generationConfig"]
|
||||
)
|
||||
assert (
|
||||
mock_call.call_args.kwargs["json"]["generationConfig"][
|
||||
"response_mime_type"
|
||||
]
|
||||
== "application/json"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
"response_schema"
|
||||
|
|
|
@ -1733,8 +1733,10 @@ def test_caching_redis_simple(caplog, capsys):
|
|||
assert redis_service_logging_error is False
|
||||
assert "async success_callback: reaches cache for logging" not in captured.out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qdrant_semantic_cache_acompletion():
|
||||
litellm.set_verbose = True
|
||||
random_number = random.randint(
|
||||
1, 100000
|
||||
) # add a random number to ensure it's always adding /reading from cache
|
||||
|
@ -1742,13 +1744,13 @@ async def test_qdrant_semantic_cache_acompletion():
|
|||
print("Testing Qdrant Semantic Caching with acompletion")
|
||||
|
||||
litellm.cache = Cache(
|
||||
type="qdrant-semantic",
|
||||
qdrant_host_type="cloud",
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
qdrant_collection_name='test_collection',
|
||||
similarity_threshold=0.8,
|
||||
qdrant_quantization_config="binary"
|
||||
type="qdrant-semantic",
|
||||
_host_type="cloud",
|
||||
qdrant_api_base=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
qdrant_collection_name="test_collection",
|
||||
similarity_threshold=0.8,
|
||||
qdrant_quantization_config="binary",
|
||||
)
|
||||
|
||||
response1 = await litellm.acompletion(
|
||||
|
@ -1759,6 +1761,7 @@ async def test_qdrant_semantic_cache_acompletion():
|
|||
"content": f"write a one sentence poem about: {random_number}",
|
||||
}
|
||||
],
|
||||
mock_response="hello",
|
||||
max_tokens=20,
|
||||
)
|
||||
print(f"Response1: {response1}")
|
||||
|
@ -1778,6 +1781,7 @@ async def test_qdrant_semantic_cache_acompletion():
|
|||
print(f"Response2: {response2}")
|
||||
assert response1.id == response2.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qdrant_semantic_cache_acompletion_stream():
|
||||
try:
|
||||
|
@ -1789,13 +1793,12 @@ async def test_qdrant_semantic_cache_acompletion_stream():
|
|||
}
|
||||
]
|
||||
litellm.cache = Cache(
|
||||
type="qdrant-semantic",
|
||||
qdrant_host_type="cloud",
|
||||
qdrant_url=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
qdrant_collection_name='test_collection',
|
||||
similarity_threshold=0.8,
|
||||
qdrant_quantization_config="binary"
|
||||
type="qdrant-semantic",
|
||||
qdrant_api_base=os.getenv("QDRANT_URL"),
|
||||
qdrant_api_key=os.getenv("QDRANT_API_KEY"),
|
||||
qdrant_collection_name="test_collection",
|
||||
similarity_threshold=0.8,
|
||||
qdrant_quantization_config="binary",
|
||||
)
|
||||
print("Test Qdrant Semantic Caching with streaming + acompletion")
|
||||
response_1_content = ""
|
||||
|
@ -1807,6 +1810,7 @@ async def test_qdrant_semantic_cache_acompletion_stream():
|
|||
max_tokens=40,
|
||||
temperature=1,
|
||||
stream=True,
|
||||
mock_response="hi",
|
||||
)
|
||||
async for chunk in response1:
|
||||
response_1_id = chunk.id
|
||||
|
@ -1830,7 +1834,9 @@ async def test_qdrant_semantic_cache_acompletion_stream():
|
|||
assert (
|
||||
response_1_content == response_2_content
|
||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
assert (response_1_id == response_2_id), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}"
|
||||
assert (
|
||||
response_1_id == response_2_id
|
||||
), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}"
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
|
|
@ -3283,9 +3283,9 @@ def test_completion_together_ai_mixtral():
|
|||
# test_completion_together_ai_mixtral()
|
||||
|
||||
|
||||
def test_completion_together_ai_yi_chat():
|
||||
def test_completion_together_ai_llama():
|
||||
litellm.set_verbose = True
|
||||
model_name = "together_ai/mistralai/Mistral-7B-Instruct-v0.1"
|
||||
model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
|
||||
try:
|
||||
messages = [
|
||||
{"role": "user", "content": "What llm are you?"},
|
||||
|
|
|
@ -909,7 +909,7 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
|||
|
||||
await team_member_add(
|
||||
data=team_member_add_request,
|
||||
user_api_key_dict=UserAPIKeyAuth(),
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
|
||||
http_request=Request(
|
||||
scope={"type": "http", "path": "/user/new"},
|
||||
),
|
||||
|
@ -930,6 +930,172 @@ async def test_create_team_member_add(prisma_client, new_member_method):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("team_member_role", ["admin", "user"])
|
||||
@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_member_add_team_admin_user_api_key_auth(
|
||||
prisma_client, team_member_role, team_route
|
||||
):
|
||||
import time
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
|
||||
from litellm.proxy.proxy_server import (
|
||||
ProxyException,
|
||||
hash_token,
|
||||
user_api_key_auth,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm, "max_internal_user_budget", 10)
|
||||
setattr(litellm, "internal_user_budget_duration", "5m")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
user = f"ishaan {uuid.uuid4().hex}"
|
||||
_team_id = "litellm-test-client-id-new"
|
||||
user_key = "sk-12345678"
|
||||
|
||||
valid_token = UserAPIKeyAuth(
|
||||
team_id=_team_id,
|
||||
token=hash_token(user_key),
|
||||
team_member=Member(role=team_member_role, user_id=user),
|
||||
last_refreshed_at=time.time(),
|
||||
)
|
||||
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||
|
||||
team_obj = LiteLLM_TeamTableCachedObj(
|
||||
team_id=_team_id,
|
||||
blocked=False,
|
||||
last_refreshed_at=time.time(),
|
||||
metadata={"guardrails": {"modify_guardrails": False}},
|
||||
)
|
||||
|
||||
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||
|
||||
## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
|
||||
import json
|
||||
|
||||
from starlette.datastructures import URL
|
||||
|
||||
request = Request(scope={"type": "http"})
|
||||
request._url = URL(url=team_route)
|
||||
|
||||
body = {}
|
||||
json_bytes = json.dumps(body).encode("utf-8")
|
||||
|
||||
request._body = json_bytes
|
||||
|
||||
## ALLOWED BY USER_API_KEY_AUTH
|
||||
await user_api_key_auth(request=request, api_key="Bearer " + user_key)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
|
||||
@pytest.mark.parametrize("user_role", ["admin", "user"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_team_member_add_team_admin(
|
||||
prisma_client, new_member_method, user_role
|
||||
):
|
||||
"""
|
||||
Relevant issue - https://github.com/BerriAI/litellm/issues/5300
|
||||
|
||||
Allow team admins to:
|
||||
- Add and remove team members
|
||||
- raise error if team member not an existing 'internal_user'
|
||||
"""
|
||||
import time
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
|
||||
from litellm.proxy.proxy_server import (
|
||||
HTTPException,
|
||||
ProxyException,
|
||||
hash_token,
|
||||
user_api_key_auth,
|
||||
user_api_key_cache,
|
||||
)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||
setattr(litellm, "max_internal_user_budget", 10)
|
||||
setattr(litellm, "internal_user_budget_duration", "5m")
|
||||
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||
user = f"ishaan {uuid.uuid4().hex}"
|
||||
_team_id = "litellm-test-client-id-new"
|
||||
user_key = "sk-12345678"
|
||||
|
||||
valid_token = UserAPIKeyAuth(
|
||||
team_id=_team_id,
|
||||
user_id=user,
|
||||
token=hash_token(user_key),
|
||||
last_refreshed_at=time.time(),
|
||||
)
|
||||
user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
|
||||
|
||||
team_obj = LiteLLM_TeamTableCachedObj(
|
||||
team_id=_team_id,
|
||||
blocked=False,
|
||||
last_refreshed_at=time.time(),
|
||||
members_with_roles=[Member(role=user_role, user_id=user)],
|
||||
metadata={"guardrails": {"modify_guardrails": False}},
|
||||
)
|
||||
|
||||
user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
|
||||
|
||||
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||
if new_member_method == "user_id":
|
||||
data = {
|
||||
"team_id": _team_id,
|
||||
"member": [{"role": "user", "user_id": user}],
|
||||
}
|
||||
elif new_member_method == "user_email":
|
||||
data = {
|
||||
"team_id": _team_id,
|
||||
"member": [{"role": "user", "user_email": user}],
|
||||
}
|
||||
team_member_add_request = TeamMemberAddRequest(**data)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_litellm_usertable:
|
||||
mock_client = AsyncMock()
|
||||
mock_litellm_usertable.upsert = mock_client
|
||||
mock_litellm_usertable.find_many = AsyncMock(return_value=None)
|
||||
|
||||
try:
|
||||
await team_member_add(
|
||||
data=team_member_add_request,
|
||||
user_api_key_dict=valid_token,
|
||||
http_request=Request(
|
||||
scope={"type": "http", "path": "/user/new"},
|
||||
),
|
||||
)
|
||||
except HTTPException as e:
|
||||
if user_role == "user":
|
||||
assert e.status_code == 403
|
||||
else:
|
||||
raise e
|
||||
|
||||
mock_client.assert_called()
|
||||
|
||||
print(f"mock_client.call_args: {mock_client.call_args}")
|
||||
print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
|
||||
|
||||
assert (
|
||||
mock_client.call_args.kwargs["data"]["create"]["max_budget"]
|
||||
== litellm.max_internal_user_budget
|
||||
)
|
||||
assert (
|
||||
mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
|
||||
== litellm.internal_user_budget_duration
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_info_team_list(prisma_client):
|
||||
"""Assert user_info for admin calls team_list function"""
|
||||
|
@ -1116,8 +1282,8 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
|
|||
"callback_name": "langfuse",
|
||||
"callback_type": "success",
|
||||
"callback_vars": {
|
||||
"langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
|
||||
"langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
|
||||
"langfuse_public_key": "my-mock-public-key",
|
||||
"langfuse_secret_key": "my-mock-secret-key",
|
||||
"langfuse_host": "https://us.cloud.langfuse.com",
|
||||
},
|
||||
}
|
||||
|
@ -1165,7 +1331,9 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
|
|||
assert "success_callback" in new_data
|
||||
assert new_data["success_callback"] == ["langfuse"]
|
||||
assert "langfuse_public_key" in new_data
|
||||
assert new_data["langfuse_public_key"] == "my-mock-public-key"
|
||||
assert "langfuse_secret_key" in new_data
|
||||
assert new_data["langfuse_secret_key"] == "my-mock-secret-key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -121,7 +121,7 @@ import importlib.metadata
|
|||
from openai import OpenAIError as OriginalError
|
||||
|
||||
from ._logging import verbose_logger
|
||||
from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
|
||||
from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
|
||||
from .exceptions import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
|
@ -8622,7 +8622,9 @@ def get_secret(
|
|||
return secret_value_as_bool
|
||||
else:
|
||||
return secret
|
||||
except:
|
||||
except Exception:
|
||||
if default_value is not None:
|
||||
return default_value
|
||||
return secret
|
||||
except Exception as e:
|
||||
if default_value is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue