Merge branch 'main' into litellm_disable_storing_master_key_hash_in_db

commit 72169fd5c4
Krish Dholakia, 2024-08-21 15:37:25 -07:00 (committed by GitHub)
No known key found for this signature in database. GPG key ID: B5690EEEBB952194
22 changed files with 537 additions and 169 deletions


@@ -161,8 +161,7 @@ random_number = random.randint(
 print("testing semantic caching")
 litellm.cache = Cache(
         type="qdrant-semantic",
-        qdrant_host_type="cloud", # can be either 'cloud' or 'local'
-        qdrant_url=os.environ["QDRANT_URL"],
+        qdrant_api_base=os.environ["QDRANT_API_BASE"],
         qdrant_api_key=os.environ["QDRANT_API_KEY"],
         qdrant_collection_name="your_collection_name", # any name of your collection
         similarity_threshold=0.7, # similarity threshold for cache hits, 0 == no similarity, 1 = exact matches, 0.5 == 50% similarity
@@ -491,12 +490,11 @@ def __init__(
     disk_cache_dir=None,
     # qdrant cache params
-    qdrant_url: Optional[str] = None,
+    qdrant_api_base: Optional[str] = None,
     qdrant_api_key: Optional[str] = None,
     qdrant_collection_name: Optional[str] = None,
     qdrant_quantization_config: Optional[str] = None,
     qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
-    qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
     **kwargs
 ):


@@ -7,6 +7,7 @@ Cache LLM Responses
 LiteLLM supports:
 - In Memory Cache
 - Redis Cache
+- Qdrant Semantic Cache
 - Redis Semantic Cache
 - s3 Bucket Cache
@@ -103,6 +104,66 @@ $ litellm --config /path/to/config.yaml
 ```
 </TabItem>
+<TabItem value="qdrant-semantic" label="Qdrant Semantic cache">
+
+Caching can be enabled by adding the `cache` key in the `config.yaml`.
+
+#### Step 1: Add `cache` to the config.yaml
+```yaml
+model_list:
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: openai-embedding
+    litellm_params:
+      model: openai/text-embedding-3-small
+      api_key: os.environ/OPENAI_API_KEY
+
+litellm_settings:
+  set_verbose: True
+  cache: True          # set cache responses to True, litellm defaults to using a redis cache
+  cache_params:
+    type: qdrant-semantic
+    qdrant_semantic_cache_embedding_model: openai-embedding # the model should be defined on the model_list
+    qdrant_collection_name: test_collection
+    qdrant_quantization_config: binary
+    similarity_threshold: 0.8   # similarity threshold for semantic cache
+```
+
+#### Step 2: Add Qdrant Credentials to your .env
+
+```shell
+QDRANT_API_KEY = "16rJUMBRx*************"
+QDRANT_API_BASE = "https://5392d382-45*********.cloud.qdrant.io"
+```
+
+#### Step 3: Run proxy with config
+```shell
+$ litellm --config /path/to/config.yaml
+```
+
+#### Step 4: Test it
+
+```shell
+curl -i http://localhost:4000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-1234" \
+  -d '{
+    "model": "fake-openai-endpoint",
+    "messages": [
+      {"role": "user", "content": "Hello"}
+    ]
+  }'
+```
+
+**Expect to see `x-litellm-semantic-similarity` in the response headers when semantic caching is enabled.**
+</TabItem>
 <TabItem value="s3" label="s3 cache">

 #### Step 1: Add `cache` to the config.yaml
@@ -182,6 +243,9 @@ REDIS_<redis-kwarg-name> = ""
 $ litellm --config /path/to/config.yaml
 ```
 </TabItem>

 </Tabs>
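As an illustrative aside (not shown in the diff), the new `x-litellm-semantic-similarity` header can be inspected from Python via the OpenAI SDK's raw-response helper, assuming the proxy configured above is running on localhost:4000 with `sk-1234` as the key:

```python
# Illustrative sketch only: read the semantic-similarity header returned by the proxy.
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

raw = client.chat.completions.with_raw_response.create(
    model="fake-openai-endpoint",
    messages=[{"role": "user", "content": "Hello"}],
)
print(raw.headers.get("x-litellm-semantic-similarity"))  # e.g. "1.0" on an exact-match cache hit
print(raw.parse().choices[0].message.content)
```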


@@ -74,6 +74,7 @@ const sidebars = {
         "proxy/alerting",
         "proxy/ui",
         "proxy/prometheus",
+        "proxy/caching",
         "proxy/pass_through",
         "proxy/email",
         "proxy/multiple_admins",
@@ -88,7 +89,6 @@ const sidebars = {
         "proxy/health",
         "proxy/debugging",
         "proxy/pii_masking",
-        "proxy/caching",
         "proxy/call_hooks",
         "proxy/rules",
         "proxy/cli",


@@ -1219,25 +1219,28 @@ class RedisSemanticCache(BaseCache):
     async def _index_info(self):
         return await self.index.ainfo()


 class QdrantSemanticCache(BaseCache):
     def __init__(
         self,
-        qdrant_url=None,
-        qdrant_api_key = None,
+        qdrant_api_base=None,
+        qdrant_api_key=None,
         collection_name=None,
         similarity_threshold=None,
         quantization_config=None,
         embedding_model="text-embedding-ada-002",
-        host_type = None
+        host_type=None,
     ):
+        import os
+
         from litellm.llms.custom_httpx.http_handler import (
-            _get_httpx_client,
-            _get_async_httpx_client
+            _get_async_httpx_client,
+            _get_httpx_client,
         )
+
         if collection_name is None:
             raise Exception("collection_name must be provided, passed None")

         self.collection_name = collection_name
         print_verbose(
             f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
@@ -1247,108 +1250,97 @@ class QdrantSemanticCache(BaseCache):
             raise Exception("similarity_threshold must be provided, passed None")
         self.similarity_threshold = similarity_threshold
         self.embedding_model = embedding_model
-        headers = {}
-        if host_type=="cloud":
-            import os
-            if qdrant_url is None:
-                qdrant_url = os.getenv('QDRANT_URL')
-            if qdrant_api_key is None:
-                qdrant_api_key = os.getenv('QDRANT_API_KEY')
-            if qdrant_url is not None and qdrant_api_key is not None:
-                headers = {
-                    "api-key": qdrant_api_key,
-                    "Content-Type": "application/json"
-                }
-            else:
-                raise Exception("Qdrant url and api_key must be provided for qdrant cloud hosting")
-        elif host_type=="local":
-            import os
-            if qdrant_url is None:
-                qdrant_url = os.getenv('QDRANT_URL')
-            if qdrant_url is None:
-                raise Exception("Qdrant url must be provided for qdrant local hosting")
-            if qdrant_api_key is None:
-                qdrant_api_key = os.getenv('QDRANT_API_KEY')
-            if qdrant_api_key is None:
-                print_verbose('Running locally without API Key.')
-                headers= {
-                    "Content-Type": "application/json"
-                }
-            else:
-                print_verbose("Running locally with API Key")
-                headers = {
-                    "api-key": qdrant_api_key,
-                    "Content-Type": "application/json"
-                }
-        else:
-            raise Exception("Host type can be either 'local' or 'cloud'")
-        self.qdrant_url = qdrant_url
+
+        # check if defined as os.environ/ variable
+        if qdrant_api_base:
+            if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
+                "os.environ/"
+            ):
+                qdrant_api_base = litellm.get_secret(qdrant_api_base)
+        if qdrant_api_key:
+            if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
+                "os.environ/"
+            ):
+                qdrant_api_key = litellm.get_secret(qdrant_api_key)
+
+        qdrant_api_base = (
+            qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
+        )
+        qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
+        headers = {"api-key": qdrant_api_key, "Content-Type": "application/json"}
+
+        if qdrant_api_key is None or qdrant_api_base is None:
+            raise ValueError("Qdrant url and api_key must be")
+
+        self.qdrant_api_base = qdrant_api_base
         self.qdrant_api_key = qdrant_api_key
-        print_verbose(f"qdrant semantic-cache qdrant_url: {self.qdrant_url}")
+        print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
         self.headers = headers

         self.sync_client = _get_httpx_client()
         self.async_client = _get_async_httpx_client()

         if quantization_config is None:
-            print('Quantization config is not provided. Default binary quantization will be used.')
+            print_verbose(
+                "Quantization config is not provided. Default binary quantization will be used."
+            )
         collection_exists = self.sync_client.get(
-            url= f"{self.qdrant_url}/collections/{self.collection_name}/exists",
-            headers=self.headers
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
+            headers=self.headers,
         )
-        if collection_exists.json()['result']['exists']:
+        if collection_exists.status_code != 200:
+            raise ValueError(
+                f"Error from qdrant checking if /collections exist {collection_exists.text}"
+            )
+
+        if collection_exists.json()["result"]["exists"]:
             collection_details = self.sync_client.get(
-                url=f"{self.qdrant_url}/collections/{self.collection_name}",
-                headers=self.headers
+                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+                headers=self.headers,
             )
             self.collection_info = collection_details.json()
-            print_verbose(f'Collection already exists.\nCollection details:{self.collection_info}')
+            print_verbose(
+                f"Collection already exists.\nCollection details:{self.collection_info}"
+            )
         else:
-            if quantization_config is None or quantization_config == 'binary':
+            if quantization_config is None or quantization_config == "binary":
                 quantization_params = {
                     "binary": {
                         "always_ram": False,
                     }
                 }
-            elif quantization_config == 'scalar':
+            elif quantization_config == "scalar":
                 quantization_params = {
-                    "scalar": {
-                        "type": "int8",
-                        "quantile": 0.99,
-                        "always_ram": False
-                    }
+                    "scalar": {"type": "int8", "quantile": 0.99, "always_ram": False}
                 }
-            elif quantization_config == 'product':
+            elif quantization_config == "product":
                 quantization_params = {
-                    "product": {
-                        "compression": "x16",
-                        "always_ram": False
-                    }
+                    "product": {"compression": "x16", "always_ram": False}
                 }
             else:
-                raise Exception("Quantization config must be one of 'scalar', 'binary' or 'product'")
+                raise Exception(
+                    "Quantization config must be one of 'scalar', 'binary' or 'product'"
+                )

             new_collection_status = self.sync_client.put(
-                url=f"{self.qdrant_url}/collections/{self.collection_name}",
+                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                 json={
-                    "vectors": {
-                        "size": 1536,
-                        "distance": "Cosine"
-                    },
-                    "quantization_config": quantization_params
+                    "vectors": {"size": 1536, "distance": "Cosine"},
+                    "quantization_config": quantization_params,
                 },
-                headers=self.headers
+                headers=self.headers,
             )
             if new_collection_status.json()["result"]:
                 collection_details = self.sync_client.get(
-                    url=f"{self.qdrant_url}/collections/{self.collection_name}",
-                    headers=self.headers
+                    url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
+                    headers=self.headers,
                 )
                 self.collection_info = collection_details.json()
-                print_verbose(f'New collection created.\nCollection details:{self.collection_info}')
+                print_verbose(
+                    f"New collection created.\nCollection details:{self.collection_info}"
+                )
             else:
                 raise Exception("Error while creating new collection")
@@ -1394,14 +1386,14 @@ class QdrantSemanticCache(BaseCache):
                     "payload": {
                         "text": prompt,
                         "response": value,
-                    }
+                    },
                 },
             ]
         }
         keys = self.sync_client.put(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
             headers=self.headers,
-            json=data
+            json=data,
         )
         return
@@ -1433,14 +1425,14 @@ class QdrantSemanticCache(BaseCache):
                     "oversampling": 3.0,
                 }
             },
-            "limit":1,
-            "with_payload": True
+            "limit": 1,
+            "with_payload": True,
         }
         search_response = self.sync_client.post(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
             headers=self.headers,
-            json=data
+            json=data,
         )
         results = search_response.json()["result"]
@@ -1470,8 +1462,10 @@ class QdrantSemanticCache(BaseCache):
         pass

     async def async_set_cache(self, key, value, **kwargs):
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
         import uuid

+        from litellm.proxy.proxy_server import llm_model_list, llm_router
+
         print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")

         # get the prompt
@@ -1519,21 +1513,21 @@ class QdrantSemanticCache(BaseCache):
                     "payload": {
                         "text": prompt,
                         "response": value,
-                    }
+                    },
                 },
             ]
         }
         keys = await self.async_client.put(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
             headers=self.headers,
-            json=data
+            json=data,
         )
         return

     async def async_get_cache(self, key, **kwargs):
         print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
-        from litellm.proxy.proxy_server import llm_router, llm_model_list
+        from litellm.proxy.proxy_server import llm_model_list, llm_router

         # get the messages
         messages = kwargs["messages"]
@@ -1578,14 +1572,14 @@ class QdrantSemanticCache(BaseCache):
                     "oversampling": 3.0,
                 }
             },
-            "limit":1,
-            "with_payload": True
+            "limit": 1,
+            "with_payload": True,
         }
         search_response = await self.async_client.post(
-            url=f"{self.qdrant_url}/collections/{self.collection_name}/points/search",
+            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
             headers=self.headers,
-            json=data
+            json=data,
        )
         results = search_response.json()["result"]
@@ -1624,6 +1618,7 @@ class QdrantSemanticCache(BaseCache):
     async def _collection_info(self):
         return self.collection_info

+
 class S3Cache(BaseCache):
     def __init__(
         self,
@@ -2129,12 +2124,11 @@ class Cache:
         redis_semantic_cache_embedding_model="text-embedding-ada-002",
         redis_flush_size=None,
         disk_cache_dir=None,
-        qdrant_url: Optional[str] = None,
+        qdrant_api_base: Optional[str] = None,
         qdrant_api_key: Optional[str] = None,
         qdrant_collection_name: Optional[str] = None,
         qdrant_quantization_config: Optional[str] = None,
         qdrant_semantic_cache_embedding_model="text-embedding-ada-002",
-        qdrant_host_type: Optional[Literal["local","cloud"]] = "local",
         **kwargs,
     ):
         """
@@ -2145,9 +2139,8 @@ class Cache:
             host (str, optional): The host address for the Redis cache. Required if type is "redis".
             port (int, optional): The port number for the Redis cache. Required if type is "redis".
             password (str, optional): The password for the Redis cache. Required if type is "redis".
-            qdrant_url (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
-            qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. Required if qdrant_host_type is "cloud" and optional if qdrant_host_type is "local".
-            qdrant_host_type (str, optional): Can be either "local" or "cloud". Should be "local" when you are running a local qdrant cluster or "cloud" when you are using a qdrant cloud cluster.
+            qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
+            qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
             qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
             similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
@@ -2176,13 +2169,12 @@ class Cache:
             )
         elif type == "qdrant-semantic":
             self.cache = QdrantSemanticCache(
-                qdrant_url= qdrant_url,
-                qdrant_api_key= qdrant_api_key,
-                collection_name= qdrant_collection_name,
-                similarity_threshold= similarity_threshold,
-                quantization_config= qdrant_quantization_config,
-                embedding_model= qdrant_semantic_cache_embedding_model,
-                host_type=qdrant_host_type
+                qdrant_api_base=qdrant_api_base,
+                qdrant_api_key=qdrant_api_key,
+                collection_name=qdrant_collection_name,
+                similarity_threshold=similarity_threshold,
+                quantization_config=qdrant_quantization_config,
+                embedding_model=qdrant_semantic_cache_embedding_model,
             )
         elif type == "local":
             self.cache = InMemoryCache()


@@ -210,7 +210,7 @@ class Logging:
         self.optional_params = optional_params
         self.model = model
         self.user = user
-        self.litellm_params = litellm_params
+        self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
         self.logger_fn = litellm_params.get("logger_fn", None)
         verbose_logger.debug(f"self.optional_params: {self.optional_params}")
@@ -2353,3 +2353,28 @@ def get_standard_logging_object_payload(
             "Error creating standard logging object - {}".format(str(e))
         )
         return None
+
+
+def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
+    if litellm_params is None:
+        litellm_params = {}
+
+    metadata = litellm_params.get("metadata", {}) or {}
+
+    ## check user_api_key_metadata for sensitive logging keys
+    cleaned_user_api_key_metadata = {}
+    if "user_api_key_metadata" in metadata and isinstance(
+        metadata["user_api_key_metadata"], dict
+    ):
+        for k, v in metadata["user_api_key_metadata"].items():
+            if k == "logging":  # prevent logging user logging keys
+                cleaned_user_api_key_metadata[k] = (
+                    "scrubbed_by_litellm_for_sensitive_keys"
+                )
+            else:
+                cleaned_user_api_key_metadata[k] = v
+
+        metadata["user_api_key_metadata"] = cleaned_user_api_key_metadata
+
+    litellm_params["metadata"] = metadata
+    return litellm_params
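For illustration, the helper added above leaves ordinary metadata untouched and replaces only the `logging` entry under `user_api_key_metadata` (values here are placeholders):

```python
params = {
    "metadata": {
        "user_api_key_metadata": {
            "team": "engineering",
            "logging": [{"callback_name": "langfuse", "callback_vars": {"langfuse_secret_key": "sk-..."}}],
        }
    }
}
cleaned = scrub_sensitive_keys_in_metadata(params)
assert cleaned["metadata"]["user_api_key_metadata"]["team"] == "engineering"
assert (
    cleaned["metadata"]["user_api_key_metadata"]["logging"]
    == "scrubbed_by_litellm_for_sensitive_keys"
)
```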


@@ -188,9 +188,11 @@ class GoogleAIStudioGeminiConfig:  # key diff from VertexAI - 'frequency_penalty
                 elif value["type"] == "text":  # type: ignore
                     optional_params["response_mime_type"] = "text/plain"
                 if "response_schema" in value:  # type: ignore
+                    optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["response_schema"]  # type: ignore
                 elif value["type"] == "json_schema":  # type: ignore
                     if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
+                        optional_params["response_mime_type"] = "application/json"
                         optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
             if param == "tools" and isinstance(value, list):
                 gtool_func_declarations = []
@@ -400,9 +402,11 @@ class VertexGeminiConfig:
                 elif value["type"] == "text":
                     optional_params["response_mime_type"] = "text/plain"
                 if "response_schema" in value:
+                    optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["response_schema"]
                 elif value["type"] == "json_schema":  # type: ignore
                     if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
+                        optional_params["response_mime_type"] = "application/json"
                         optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
             if param == "frequency_penalty":
                 optional_params["frequency_penalty"] = value


@@ -4,4 +4,4 @@ model_list:
       model: "*"

 general_settings:
   disable_adding_master_key_hash_to_db: True


@@ -21,6 +21,13 @@ else:
     Span = Any


+class LiteLLMTeamRoles(enum.Enum):
+    # team admin
+    TEAM_ADMIN = "admin"
+    # team member
+    TEAM_MEMBER = "user"
+
+
 class LitellmUserRoles(str, enum.Enum):
     """
     Admin Roles:
@@ -335,6 +342,11 @@ class LiteLLMRoutes(enum.Enum):
         + sso_only_routes
     )

+    self_managed_routes: List = [
+        "/team/member_add",
+        "/team/member_delete",
+    ]  # routes that manage their own allowed/disallowed logic
+

 # class LiteLLMAllowedRoutes(LiteLLMBase):
 #     """
@@ -1308,6 +1320,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
     soft_budget: Optional[float] = None
     team_model_aliases: Optional[Dict] = None
     team_member_spend: Optional[float] = None
+    team_member: Optional[Member] = None
     team_metadata: Optional[Dict] = None

     # End User Params


@@ -975,8 +975,6 @@ async def user_api_key_auth(
             if not _is_user_proxy_admin(user_obj=user_obj):  # if non-admin
                 if is_llm_api_route(route=route):
                     pass
-                elif is_llm_api_route(route=request["route"].name):
-                    pass
                 elif (
                     route in LiteLLMRoutes.info_routes.value
                 ):  # check if user allowed to call an info route
@@ -1046,11 +1044,16 @@ async def user_api_key_auth(
                         status_code=status.HTTP_403_FORBIDDEN,
                         detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
                     )
+
                 elif (
                     _user_role == LitellmUserRoles.INTERNAL_USER.value
                     and route in LiteLLMRoutes.internal_user_routes.value
                 ):
                     pass
+                elif (
+                    route in LiteLLMRoutes.self_managed_routes.value
+                ):  # routes that manage their own allowed/disallowed logic
+                    pass
+
                 else:
                     user_role = "unknown"
                     user_id = "unknown"


@@ -285,14 +285,18 @@ def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str,
     return headers


-def get_applied_guardrails_header(request_data: Dict) -> Optional[Dict]:
+def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
     _metadata = request_data.get("metadata", None) or {}
+    headers = {}
     if "applied_guardrails" in _metadata:
-        return {
-            "x-litellm-applied-guardrails": ",".join(_metadata["applied_guardrails"]),
-        }
+        headers["x-litellm-applied-guardrails"] = ",".join(
+            _metadata["applied_guardrails"]
+        )

-    return None
+    if "semantic-similarity" in _metadata:
+        headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"])
+
+    return headers


 def add_guardrail_to_applied_guardrails_header(
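Illustrative only: with both metadata keys present, the renamed helper now returns both headers in one dict.

```python
request_data = {
    "metadata": {
        "applied_guardrails": ["lakera-pre-guard"],
        "semantic-similarity": 0.92,
    }
}
headers = get_logging_caching_headers(request_data)
# -> {"x-litellm-applied-guardrails": "lakera-pre-guard",
#     "x-litellm-semantic-similarity": "0.92"}
```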


@@ -95,7 +95,9 @@ def convert_key_logging_metadata_to_callback(
         for var, value in data.callback_vars.items():
             if team_callback_settings_obj.callback_vars is None:
                 team_callback_settings_obj.callback_vars = {}
-            team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
+            team_callback_settings_obj.callback_vars[var] = (
+                litellm.utils.get_secret(value, default_value=value) or value
+            )

     return team_callback_settings_obj
@@ -130,7 +132,6 @@ def _get_dynamic_logging_metadata(
                 data=AddTeamCallback(**item),
                 team_callback_settings_obj=callback_settings_obj,
             )
     return callback_settings_obj


@@ -119,6 +119,7 @@ async def new_user(
             http_request=Request(
                 scope={"type": "http", "path": "/user/new"},
             ),
+            user_api_key_dict=user_api_key_dict,
         )

     if data.send_invite_email is True:


@@ -849,7 +849,7 @@ async def generate_key_helper_fn(
     }

     if (
-        litellm.get_secret("DISABLE_KEY_NAME", False) == True
+        litellm.get_secret("DISABLE_KEY_NAME", False) is True
     ):  # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
         pass
     else:


@@ -30,7 +30,7 @@ from litellm.proxy._types import (
     UpdateTeamRequest,
     UserAPIKeyAuth,
 )
-from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.auth.user_api_key_auth import _is_user_proxy_admin, user_api_key_auth
 from litellm.proxy.management_helpers.utils import (
     add_new_member,
     management_endpoint_wrapper,
@@ -39,6 +39,16 @@ from litellm.proxy.management_helpers.utils import (
 router = APIRouter()


+def _is_user_team_admin(
+    user_api_key_dict: UserAPIKeyAuth, team_obj: LiteLLM_TeamTable
+) -> bool:
+    for member in team_obj.members_with_roles:
+        if member.user_id is not None and member.user_id == user_api_key_dict.user_id:
+            return True
+
+    return False
+
+
 #### TEAM MANAGEMENT ####

 @router.post(
     "/team/new",
@@ -417,6 +427,7 @@ async def team_member_add(
     If user doesn't exist, new user row will also be added to User Table

+    Only proxy_admin or admin of team, allowed to access this endpoint.
     ```

     curl -X POST 'http://0.0.0.0:4000/team/member_add' \
@@ -465,6 +476,24 @@ async def team_member_add(
     complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())

+    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
+
+    if (
+        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
+        and not _is_user_team_admin(
+            user_api_key_dict=user_api_key_dict, team_obj=complete_team_data
+        )
+    ):
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
+                    "/team/member_add",
+                    complete_team_data.team_id,
+                )
+            },
+        )
+
     if isinstance(data.member, Member):
         # add to team db
         new_member = data.member
@@ -569,6 +598,23 @@ async def team_member_delete(
         )
     existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())

+    ## CHECK IF USER IS PROXY ADMIN OR TEAM ADMIN
+
+    if (
+        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
+        and not _is_user_team_admin(
+            user_api_key_dict=user_api_key_dict, team_obj=existing_team_row
+        )
+    ):
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "error": "Call not allowed. User not proxy admin OR team admin. route={}, team_id={}".format(
+                    "/team/member_delete", existing_team_row.team_id
+                )
+            },
+        )
+
     ## DELETE MEMBER FROM TEAM
     new_team_members: List[Member] = []
     for m in existing_team_row.members_with_roles:
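Illustrative sketch (key, team id, and user id below are placeholders): with this check in place, a team admin's key can call the member-management routes, while other callers get the 403 raised above.

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/team/member_add",
    headers={"Authorization": "Bearer sk-team-admin-key"},  # placeholder key
    json={"team_id": "my-team-id", "member": {"role": "user", "user_id": "new-user-id"}},
)
print(resp.status_code)  # 403 unless the caller is the proxy admin or a team admin
```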


@@ -4,15 +4,17 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: openai-embedding
+    litellm_params:
+      model: openai/text-embedding-3-small
+      api_key: os.environ/OPENAI_API_KEY

-guardrails:
-  - guardrail_name: "lakera-pre-guard"
-    litellm_params:
-      guardrail: lakera  # supported values: "aporia", "bedrock", "lakera"
-      mode: "during_call"
-      api_key: os.environ/LAKERA_API_KEY
-      api_base: os.environ/LAKERA_API_BASE
-      category_thresholds:
-        prompt_injection: 0.1
-        jailbreak: 0.1
+litellm_settings:
+  set_verbose: True
+  cache: True          # set cache responses to True, litellm defaults to using a redis cache
+  cache_params:
+    type: qdrant-semantic
+    qdrant_semantic_cache_embedding_model: openai-embedding
+    qdrant_collection_name: test_collection
+    qdrant_quantization_config: binary
+    similarity_threshold: 0.8   # similarity threshold for semantic cache


@@ -149,7 +149,7 @@ from litellm.proxy.common_utils.admin_ui_utils import (
     show_missing_vars_in_env,
 )
 from litellm.proxy.common_utils.callback_utils import (
-    get_applied_guardrails_header,
+    get_logging_caching_headers,
     get_remaining_tokens_and_requests_from_request_data,
     initialize_callbacks_on_proxy,
 )
@@ -543,9 +543,9 @@ def get_custom_headers(
     )
     headers.update(remaining_tokens_header)

-    applied_guardrails = get_applied_guardrails_header(request_data)
-    if applied_guardrails:
-        headers.update(applied_guardrails)
+    logging_caching_headers = get_logging_caching_headers(request_data)
+    if logging_caching_headers:
+        headers.update(logging_caching_headers)

     try:
         return {


@@ -44,6 +44,7 @@ from litellm.proxy._types import (
     DynamoDBArgs,
     LiteLLM_VerificationTokenView,
     LitellmUserRoles,
+    Member,
     ResetTeamBudgetRequest,
     SpendLogsMetadata,
     SpendLogsPayload,
@@ -1395,6 +1396,7 @@ class PrismaClient:
                     t.blocked AS team_blocked,
                     t.team_alias AS team_alias,
                     t.metadata AS team_metadata,
+                    t.members_with_roles AS team_members_with_roles,
                     tm.spend AS team_member_spend,
                     m.aliases as team_model_aliases
                 FROM "LiteLLM_VerificationToken" AS v
@@ -1412,6 +1414,33 @@ class PrismaClient:
                     response["team_models"] = []
                 if response["team_blocked"] is None:
                     response["team_blocked"] = False
+
+                team_member: Optional[Member] = None
+                if (
+                    response["team_members_with_roles"] is not None
+                    and response["user_id"] is not None
+                ):
+                    ## find the team member corresponding to user id
+                    """
+                    [
+                        {
+                            "role": "admin",
+                            "user_id": "default_user_id",
+                            "user_email": null
+                        },
+                        {
+                            "role": "user",
+                            "user_id": null,
+                            "user_email": "test@email.com"
+                        }
+                    ]
+                    """
+                    for tm in response["team_members_with_roles"]:
+                        if tm.get("user_id") is not None and response[
+                            "user_id"
+                        ] == tm.get("user_id"):
+                            team_member = Member(**tm)
+                response["team_member"] = team_member
+
                 response = LiteLLM_VerificationTokenView(
                     **response, last_refreshed_at=time.time()
                 )


@@ -1558,6 +1558,16 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
                 "response_schema"
                 in mock_call.call_args.kwargs["json"]["generationConfig"]
             )
+            assert (
+                "response_mime_type"
+                in mock_call.call_args.kwargs["json"]["generationConfig"]
+            )
+            assert (
+                mock_call.call_args.kwargs["json"]["generationConfig"][
+                    "response_mime_type"
+                ]
+                == "application/json"
+            )
         else:
             assert (
                 "response_schema"


@@ -1733,8 +1733,10 @@ def test_caching_redis_simple(caplog, capsys):
     assert redis_service_logging_error is False
     assert "async success_callback: reaches cache for logging" not in captured.out


 @pytest.mark.asyncio
 async def test_qdrant_semantic_cache_acompletion():
+    litellm.set_verbose = True
     random_number = random.randint(
         1, 100000
     )  # add a random number to ensure it's always adding /reading from cache
@@ -1742,13 +1744,13 @@ async def test_qdrant_semantic_cache_acompletion():
     print("Testing Qdrant Semantic Caching with acompletion")

     litellm.cache = Cache(
         type="qdrant-semantic",
-        qdrant_host_type="cloud",
-        qdrant_url=os.getenv("QDRANT_URL"),
+        qdrant_api_base=os.getenv("QDRANT_URL"),
         qdrant_api_key=os.getenv("QDRANT_API_KEY"),
-        qdrant_collection_name='test_collection',
+        qdrant_collection_name="test_collection",
         similarity_threshold=0.8,
-        qdrant_quantization_config="binary"
+        qdrant_quantization_config="binary",
     )

     response1 = await litellm.acompletion(
@@ -1759,6 +1761,7 @@ async def test_qdrant_semantic_cache_acompletion():
                 "content": f"write a one sentence poem about: {random_number}",
             }
         ],
+        mock_response="hello",
         max_tokens=20,
     )
     print(f"Response1: {response1}")
@@ -1778,6 +1781,7 @@ async def test_qdrant_semantic_cache_acompletion():
     print(f"Response2: {response2}")
     assert response1.id == response2.id

+
 @pytest.mark.asyncio
 async def test_qdrant_semantic_cache_acompletion_stream():
     try:
@@ -1789,13 +1793,12 @@ async def test_qdrant_semantic_cache_acompletion_stream():
             }
         ]
         litellm.cache = Cache(
             type="qdrant-semantic",
-            qdrant_host_type="cloud",
-            qdrant_url=os.getenv("QDRANT_URL"),
+            qdrant_api_base=os.getenv("QDRANT_URL"),
             qdrant_api_key=os.getenv("QDRANT_API_KEY"),
-            qdrant_collection_name='test_collection',
+            qdrant_collection_name="test_collection",
             similarity_threshold=0.8,
-            qdrant_quantization_config="binary"
+            qdrant_quantization_config="binary",
         )
         print("Test Qdrant Semantic Caching with streaming + acompletion")
         response_1_content = ""
@@ -1807,6 +1810,7 @@ async def test_qdrant_semantic_cache_acompletion_stream():
             max_tokens=40,
             temperature=1,
             stream=True,
+            mock_response="hi",
         )
         async for chunk in response1:
             response_1_id = chunk.id
@@ -1830,7 +1834,9 @@ async def test_qdrant_semantic_cache_acompletion_stream():
         assert (
             response_1_content == response_2_content
         ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
-        assert (response_1_id == response_2_id), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}"
+        assert (
+            response_1_id == response_2_id
+        ), f"Response 1 id != Response 2 id, Response 1 id: {response_1_id} != Response 2 id: {response_2_id}"

         litellm.cache = None
         litellm.success_callback = []
         litellm._async_success_callback = []


@@ -3283,9 +3283,9 @@ def test_completion_together_ai_mixtral():
 # test_completion_together_ai_mixtral()


-def test_completion_together_ai_yi_chat():
+def test_completion_together_ai_llama():
     litellm.set_verbose = True
-    model_name = "together_ai/mistralai/Mistral-7B-Instruct-v0.1"
+    model_name = "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
     try:
         messages = [
             {"role": "user", "content": "What llm are you?"},


@@ -909,7 +909,7 @@ async def test_create_team_member_add(prisma_client, new_member_method):
         await team_member_add(
             data=team_member_add_request,
-            user_api_key_dict=UserAPIKeyAuth(),
+            user_api_key_dict=UserAPIKeyAuth(user_role="proxy_admin"),
             http_request=Request(
                 scope={"type": "http", "path": "/user/new"},
             ),
@@ -930,6 +930,172 @@ async def test_create_team_member_add(prisma_client, new_member_method):
         )


+@pytest.mark.parametrize("team_member_role", ["admin", "user"])
+@pytest.mark.parametrize("team_route", ["/team/member_add", "/team/member_delete"])
+@pytest.mark.asyncio
+async def test_create_team_member_add_team_admin_user_api_key_auth(
+    prisma_client, team_member_role, team_route
+):
+    import time
+
+    from fastapi import Request
+
+    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
+    from litellm.proxy.proxy_server import (
+        ProxyException,
+        hash_token,
+        user_api_key_auth,
+        user_api_key_cache,
+    )
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm, "max_internal_user_budget", 10)
+    setattr(litellm, "internal_user_budget_duration", "5m")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    user = f"ishaan {uuid.uuid4().hex}"
+    _team_id = "litellm-test-client-id-new"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        token=hash_token(user_key),
+        team_member=Member(role=team_member_role, user_id=user),
+        last_refreshed_at=time.time(),
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+
+    team_obj = LiteLLM_TeamTableCachedObj(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+
+    ## TEST IF TEAM ADMIN ALLOWED TO CALL /MEMBER_ADD ENDPOINT
+    import json
+
+    from starlette.datastructures import URL
+
+    request = Request(scope={"type": "http"})
+    request._url = URL(url=team_route)
+
+    body = {}
+    json_bytes = json.dumps(body).encode("utf-8")
+
+    request._body = json_bytes
+
+    ## ALLOWED BY USER_API_KEY_AUTH
+    await user_api_key_auth(request=request, api_key="Bearer " + user_key)
+
+
+@pytest.mark.parametrize("new_member_method", ["user_id", "user_email"])
+@pytest.mark.parametrize("user_role", ["admin", "user"])
+@pytest.mark.asyncio
+async def test_create_team_member_add_team_admin(
+    prisma_client, new_member_method, user_role
+):
+    """
+    Relevant issue - https://github.com/BerriAI/litellm/issues/5300
+
+    Allow team admins to:
+        - Add and remove team members
+        - raise error if team member not an existing 'internal_user'
+    """
+    import time
+
+    from fastapi import Request
+
+    from litellm.proxy._types import LiteLLM_TeamTableCachedObj, Member
+    from litellm.proxy.proxy_server import (
+        HTTPException,
+        ProxyException,
+        hash_token,
+        user_api_key_auth,
+        user_api_key_cache,
+    )
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    setattr(litellm, "max_internal_user_budget", 10)
+    setattr(litellm, "internal_user_budget_duration", "5m")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    user = f"ishaan {uuid.uuid4().hex}"
+    _team_id = "litellm-test-client-id-new"
+    user_key = "sk-12345678"
+
+    valid_token = UserAPIKeyAuth(
+        team_id=_team_id,
+        user_id=user,
+        token=hash_token(user_key),
+        last_refreshed_at=time.time(),
+    )
+    user_api_key_cache.set_cache(key=hash_token(user_key), value=valid_token)
+
+    team_obj = LiteLLM_TeamTableCachedObj(
+        team_id=_team_id,
+        blocked=False,
+        last_refreshed_at=time.time(),
+        members_with_roles=[Member(role=user_role, user_id=user)],
+        metadata={"guardrails": {"modify_guardrails": False}},
+    )
+
+    user_api_key_cache.set_cache(key="team_id:{}".format(_team_id), value=team_obj)
+
+    setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
+
+    if new_member_method == "user_id":
+        data = {
+            "team_id": _team_id,
+            "member": [{"role": "user", "user_id": user}],
+        }
+    elif new_member_method == "user_email":
+        data = {
+            "team_id": _team_id,
+            "member": [{"role": "user", "user_email": user}],
+        }
+    team_member_add_request = TeamMemberAddRequest(**data)
+
+    with patch(
+        "litellm.proxy.proxy_server.prisma_client.db.litellm_usertable",
+        new_callable=AsyncMock,
+    ) as mock_litellm_usertable:
+        mock_client = AsyncMock()
+        mock_litellm_usertable.upsert = mock_client
+        mock_litellm_usertable.find_many = AsyncMock(return_value=None)
+
+        try:
+            await team_member_add(
+                data=team_member_add_request,
+                user_api_key_dict=valid_token,
+                http_request=Request(
+                    scope={"type": "http", "path": "/user/new"},
+                ),
+            )
+        except HTTPException as e:
+            if user_role == "user":
+                assert e.status_code == 403
+            else:
+                raise e
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args: {mock_client.call_args}")
+        print("mock_client.call_args.kwargs: {}".format(mock_client.call_args.kwargs))
+
+        assert (
+            mock_client.call_args.kwargs["data"]["create"]["max_budget"]
+            == litellm.max_internal_user_budget
+        )
+        assert (
+            mock_client.call_args.kwargs["data"]["create"]["budget_duration"]
+            == litellm.internal_user_budget_duration
+        )
+
+
 @pytest.mark.asyncio
 async def test_user_info_team_list(prisma_client):
     """Assert user_info for admin calls team_list function"""
@@ -1116,8 +1282,8 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
                     "callback_name": "langfuse",
                     "callback_type": "success",
                     "callback_vars": {
-                        "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
-                        "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+                        "langfuse_public_key": "my-mock-public-key",
+                        "langfuse_secret_key": "my-mock-secret-key",
                         "langfuse_host": "https://us.cloud.langfuse.com",
                     },
                 }
@@ -1165,7 +1331,9 @@ async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
     assert "success_callback" in new_data
     assert new_data["success_callback"] == ["langfuse"]
     assert "langfuse_public_key" in new_data
+    assert new_data["langfuse_public_key"] == "my-mock-public-key"
     assert "langfuse_secret_key" in new_data
+    assert new_data["langfuse_secret_key"] == "my-mock-secret-key"


 @pytest.mark.asyncio


@@ -121,7 +121,7 @@ import importlib.metadata
 from openai import OpenAIError as OriginalError

 from ._logging import verbose_logger
-from .caching import RedisCache, RedisSemanticCache, S3Cache, QdrantSemanticCache
+from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
 from .exceptions import (
     APIConnectionError,
     APIError,
@@ -8622,7 +8622,9 @@ def get_secret(
                     return secret_value_as_bool
                 else:
                     return secret
-            except:
+            except Exception:
+                if default_value is not None:
+                    return default_value
                 return secret
     except Exception as e:
         if default_value is not None:
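Illustrative usage of the fallback above: when the lookup cannot be resolved, `get_secret` now returns the caller-supplied default instead of raising (the variable name and default below are placeholders).

```python
value = litellm.get_secret(
    "os.environ/LANGFUSE_PUBLIC_KEY", default_value="my-mock-public-key"
)
# -> the environment variable's value if it is set, otherwise "my-mock-public-key"
```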