LiteLLM Minor Fixes & Improvements (09/21/2024) (#5819)

* fix(router.py): fix error message

* Litellm disable keys (#5814)

* build(schema.prisma): allow blocking/unblocking keys

Fixes https://github.com/BerriAI/litellm/issues/5328

* fix(key_management_endpoints.py): fix pop

* feat(auth_checks.py): allow admin to enable/disable virtual keys

Closes https://github.com/BerriAI/litellm/issues/5328

* docs(vertex.md): add auth section for vertex ai

Addresses - https://github.com/BerriAI/litellm/issues/5768#issuecomment-2365284223

* build(model_prices_and_context_window.json): show which models support prompt_caching

Closes https://github.com/BerriAI/litellm/issues/5776

* fix(router.py): allow setting default priority for requests

* fix(router.py): add 'retry-after' header for concurrent request limit errors

Fixes https://github.com/BerriAI/litellm/issues/5783

* fix(router.py): correctly raise and use retry-after header from azure+openai

Fixes https://github.com/BerriAI/litellm/issues/5783

* fix(user_api_key_auth.py): fix valid token being none

* fix(auth_checks.py): fix model dump for cache management object

* fix(user_api_key_auth.py): pass prisma_client to obj

* test(test_otel.py): update test for new key check

* test: fix test
This commit is contained in:
Krish Dholakia 2024-09-21 18:51:53 -07:00 committed by GitHub
parent 1ca638973f
commit 8039b95aaf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 1006 additions and 182 deletions

View file

@ -207,7 +207,7 @@ class RedisCache(BaseCache):
host=None,
port=None,
password=None,
redis_flush_size=100,
redis_flush_size: Optional[int] = 100,
namespace: Optional[str] = None,
startup_nodes: Optional[List] = None, # for redis-cluster
**kwargs,
@ -244,7 +244,10 @@ class RedisCache(BaseCache):
self.namespace = namespace
# for high traffic, we store the redis results in memory and then batch write to redis
self.redis_batch_writing_buffer: list = []
self.redis_flush_size = redis_flush_size
if redis_flush_size is None:
self.redis_flush_size: int = 100
else:
self.redis_flush_size = redis_flush_size
self.redis_version = "Unknown"
try:
self.redis_version = self.redis_client.info()["redis_version"]
@ -317,7 +320,7 @@ class RedisCache(BaseCache):
current_ttl = _redis_client.ttl(key)
if current_ttl == -1:
# Key has no expiration
_redis_client.expire(key, ttl)
_redis_client.expire(key, ttl) # type: ignore
return result
except Exception as e:
## LOGGING ##
@ -331,10 +334,13 @@ class RedisCache(BaseCache):
raise e
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
from redis.asyncio import Redis
start_time = time.time()
try:
keys = []
_redis_client = self.init_async_client()
_redis_client: Redis = self.init_async_client() # type: ignore
async with _redis_client as redis_client:
async for key in redis_client.scan_iter(
match=pattern + "*", count=count
@ -374,9 +380,11 @@ class RedisCache(BaseCache):
raise e
async def async_set_cache(self, key, value, **kwargs):
from redis.asyncio import Redis
start_time = time.time()
try:
_redis_client = self.init_async_client()
_redis_client: Redis = self.init_async_client() # type: ignore
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
@ -397,6 +405,7 @@ class RedisCache(BaseCache):
str(e),
value,
)
raise e
key = self.check_and_fix_namespace(key=key)
async with _redis_client as redis_client:
@ -405,6 +414,10 @@ class RedisCache(BaseCache):
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
if not hasattr(redis_client, "set"):
raise Exception(
"Redis client cannot set cache. Attribute not found."
)
await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
print_verbose(
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
@ -446,12 +459,15 @@ class RedisCache(BaseCache):
"""
Use Redis Pipelines for bulk write operations
"""
_redis_client = self.init_async_client()
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()
print_verbose(
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
)
cache_value: Any = None
try:
async with _redis_client as redis_client:
async with redis_client.pipeline(transaction=True) as pipe:
@ -463,6 +479,7 @@ class RedisCache(BaseCache):
)
json_cache_value = json.dumps(cache_value)
# Set the value with a TTL if it's provided.
if ttl is not None:
pipe.setex(cache_key, ttl, json_cache_value)
else:
@ -511,9 +528,11 @@ class RedisCache(BaseCache):
async def async_set_cache_sadd(
self, key, value: List, ttl: Optional[float], **kwargs
):
from redis.asyncio import Redis
start_time = time.time()
try:
_redis_client = self.init_async_client()
_redis_client: Redis = self.init_async_client() # type: ignore
except Exception as e:
end_time = time.time()
_duration = end_time - start_time
@ -592,9 +611,11 @@ class RedisCache(BaseCache):
await self.flush_cache_buffer() # logging done in here
async def async_increment(
self, key, value: float, ttl: Optional[float] = None, **kwargs
self, key, value: float, ttl: Optional[int] = None, **kwargs
) -> float:
_redis_client = self.init_async_client()
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
start_time = time.time()
try:
async with _redis_client as redis_client:
@ -708,7 +729,9 @@ class RedisCache(BaseCache):
return key_value_dict
async def async_get_cache(self, key, **kwargs):
_redis_client = self.init_async_client()
from redis.asyncio import Redis
_redis_client: Redis = self.init_async_client() # type: ignore
key = self.check_and_fix_namespace(key=key)
start_time = time.time()
async with _redis_client as redis_client:
@ -903,6 +926,12 @@ class RedisCache(BaseCache):
async def disconnect(self):
await self.async_redis_conn_pool.disconnect(inuse_connections=True)
async def async_delete_cache(self, key: str):
_redis_client = self.init_async_client()
# keys is str
async with _redis_client as redis_client:
await redis_client.delete(key)
def delete_cache(self, key):
self.redis_client.delete(key)
@ -1241,6 +1270,7 @@ class QdrantSemanticCache(BaseCache):
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.secret_managers.main import get_secret_str
if collection_name is None:
raise Exception("collection_name must be provided, passed None")
@ -1261,12 +1291,12 @@ class QdrantSemanticCache(BaseCache):
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
"os.environ/"
):
qdrant_api_base = litellm.get_secret(qdrant_api_base)
qdrant_api_base = get_secret_str(qdrant_api_base)
if qdrant_api_key:
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
"os.environ/"
):
qdrant_api_key = litellm.get_secret(qdrant_api_key)
qdrant_api_key = get_secret_str(qdrant_api_key)
qdrant_api_base = (
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
@ -1633,7 +1663,7 @@ class S3Cache(BaseCache):
s3_bucket_name,
s3_region_name=None,
s3_api_version=None,
s3_use_ssl=True,
s3_use_ssl: Optional[bool] = True,
s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
@ -1721,7 +1751,7 @@ class S3Cache(BaseCache):
Bucket=self.bucket_name, Key=key
)
if cached_response != None:
if cached_response is not None:
# cached_response is in `b{} convert it to ModelResponse
cached_response = (
cached_response["Body"].read().decode("utf-8")
@ -1739,7 +1769,7 @@ class S3Cache(BaseCache):
)
return cached_response
except botocore.exceptions.ClientError as e:
except botocore.exceptions.ClientError as e: # type: ignore
if e.response["Error"]["Code"] == "NoSuchKey":
verbose_logger.debug(
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
@ -2081,6 +2111,15 @@ class DualCache(BaseCache):
if self.redis_cache is not None:
self.redis_cache.delete_cache(key)
async def async_delete_cache(self, key: str):
"""
Delete a key from the cache
"""
if self.in_memory_cache is not None:
self.in_memory_cache.delete_cache(key)
if self.redis_cache is not None:
await self.redis_cache.async_delete_cache(key)
#### LiteLLM.Completion / Embedding Cache ####
class Cache:
@ -2137,7 +2176,7 @@ class Cache:
s3_path: Optional[str] = None,
redis_semantic_cache_use_async=False,
redis_semantic_cache_embedding_model="text-embedding-ada-002",
redis_flush_size=None,
redis_flush_size: Optional[int] = None,
redis_startup_nodes: Optional[List] = None,
disk_cache_dir=None,
qdrant_api_base: Optional[str] = None,
@ -2501,10 +2540,9 @@ class Cache:
if self.ttl is not None:
kwargs["ttl"] = self.ttl
## Get Cache-Controls ##
if kwargs.get("cache", None) is not None and isinstance(
kwargs.get("cache"), dict
):
for k, v in kwargs.get("cache").items():
_cache_kwargs = kwargs.get("cache", None)
if isinstance(_cache_kwargs, dict):
for k, v in _cache_kwargs.items():
if k == "ttl":
kwargs["ttl"] = v
@ -2574,14 +2612,15 @@ class Cache:
**kwargs,
)
cache_list.append((cache_key, cached_data))
if hasattr(self.cache, "async_set_cache_pipeline"):
await self.cache.async_set_cache_pipeline(cache_list=cache_list)
async_set_cache_pipeline = getattr(
self.cache, "async_set_cache_pipeline", None
)
if async_set_cache_pipeline:
await async_set_cache_pipeline(cache_list=cache_list)
else:
tasks = []
for val in cache_list:
tasks.append(
self.cache.async_set_cache(cache_key, cached_data, **kwargs)
)
tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
@ -2611,13 +2650,15 @@ class Cache:
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
async def ping(self):
if hasattr(self.cache, "ping"):
return await self.cache.ping()
cache_ping = getattr(self.cache, "ping")
if cache_ping:
return await cache_ping()
return None
async def delete_cache_keys(self, keys):
if hasattr(self.cache, "delete_cache_keys"):
return await self.cache.delete_cache_keys(keys)
cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
if cache_delete_cache_keys:
return await cache_delete_cache_keys(keys)
return None
async def disconnect(self):