mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
LiteLLM Minor Fixes & Improvements (09/21/2024) (#5819)
* fix(router.py): fix error message * Litellm disable keys (#5814) * build(schema.prisma): allow blocking/unblocking keys Fixes https://github.com/BerriAI/litellm/issues/5328 * fix(key_management_endpoints.py): fix pop * feat(auth_checks.py): allow admin to enable/disable virtual keys Closes https://github.com/BerriAI/litellm/issues/5328 * docs(vertex.md): add auth section for vertex ai Addresses - https://github.com/BerriAI/litellm/issues/5768#issuecomment-2365284223 * build(model_prices_and_context_window.json): show which models support prompt_caching Closes https://github.com/BerriAI/litellm/issues/5776 * fix(router.py): allow setting default priority for requests * fix(router.py): add 'retry-after' header for concurrent request limit errors Fixes https://github.com/BerriAI/litellm/issues/5783 * fix(router.py): correctly raise and use retry-after header from azure+openai Fixes https://github.com/BerriAI/litellm/issues/5783 * fix(user_api_key_auth.py): fix valid token being none * fix(auth_checks.py): fix model dump for cache management object * fix(user_api_key_auth.py): pass prisma_client to obj * test(test_otel.py): update test for new key check * test: fix test
This commit is contained in:
parent
f0543a6f9d
commit
f3fa2160a0
25 changed files with 1006 additions and 182 deletions
|
@ -207,7 +207,7 @@ class RedisCache(BaseCache):
|
|||
host=None,
|
||||
port=None,
|
||||
password=None,
|
||||
redis_flush_size=100,
|
||||
redis_flush_size: Optional[int] = 100,
|
||||
namespace: Optional[str] = None,
|
||||
startup_nodes: Optional[List] = None, # for redis-cluster
|
||||
**kwargs,
|
||||
|
@ -244,7 +244,10 @@ class RedisCache(BaseCache):
|
|||
self.namespace = namespace
|
||||
# for high traffic, we store the redis results in memory and then batch write to redis
|
||||
self.redis_batch_writing_buffer: list = []
|
||||
self.redis_flush_size = redis_flush_size
|
||||
if redis_flush_size is None:
|
||||
self.redis_flush_size: int = 100
|
||||
else:
|
||||
self.redis_flush_size = redis_flush_size
|
||||
self.redis_version = "Unknown"
|
||||
try:
|
||||
self.redis_version = self.redis_client.info()["redis_version"]
|
||||
|
@ -317,7 +320,7 @@ class RedisCache(BaseCache):
|
|||
current_ttl = _redis_client.ttl(key)
|
||||
if current_ttl == -1:
|
||||
# Key has no expiration
|
||||
_redis_client.expire(key, ttl)
|
||||
_redis_client.expire(key, ttl) # type: ignore
|
||||
return result
|
||||
except Exception as e:
|
||||
## LOGGING ##
|
||||
|
@ -331,10 +334,13 @@ class RedisCache(BaseCache):
|
|||
raise e
|
||||
|
||||
async def async_scan_iter(self, pattern: str, count: int = 100) -> list:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
keys = []
|
||||
_redis_client = self.init_async_client()
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
|
||||
async with _redis_client as redis_client:
|
||||
async for key in redis_client.scan_iter(
|
||||
match=pattern + "*", count=count
|
||||
|
@ -374,9 +380,11 @@ class RedisCache(BaseCache):
|
|||
raise e
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
from redis.asyncio import Redis
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
_redis_client = self.init_async_client()
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
|
@ -397,6 +405,7 @@ class RedisCache(BaseCache):
|
|||
str(e),
|
||||
value,
|
||||
)
|
||||
raise e
|
||||
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
async with _redis_client as redis_client:
|
||||
|
@ -405,6 +414,10 @@ class RedisCache(BaseCache):
|
|||
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||
)
|
||||
try:
|
||||
if not hasattr(redis_client, "set"):
|
||||
raise Exception(
|
||||
"Redis client cannot set cache. Attribute not found."
|
||||
)
|
||||
await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
|
||||
print_verbose(
|
||||
f"Successfully Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
|
||||
|
@ -446,12 +459,15 @@ class RedisCache(BaseCache):
|
|||
"""
|
||||
Use Redis Pipelines for bulk write operations
|
||||
"""
|
||||
_redis_client = self.init_async_client()
|
||||
from redis.asyncio import Redis
|
||||
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
start_time = time.time()
|
||||
|
||||
print_verbose(
|
||||
f"Set Async Redis Cache: key list: {cache_list}\nttl={ttl}, redis_version={self.redis_version}"
|
||||
)
|
||||
cache_value: Any = None
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
async with redis_client.pipeline(transaction=True) as pipe:
|
||||
|
@ -463,6 +479,7 @@ class RedisCache(BaseCache):
|
|||
)
|
||||
json_cache_value = json.dumps(cache_value)
|
||||
# Set the value with a TTL if it's provided.
|
||||
|
||||
if ttl is not None:
|
||||
pipe.setex(cache_key, ttl, json_cache_value)
|
||||
else:
|
||||
|
@ -511,9 +528,11 @@ class RedisCache(BaseCache):
|
|||
async def async_set_cache_sadd(
|
||||
self, key, value: List, ttl: Optional[float], **kwargs
|
||||
):
|
||||
from redis.asyncio import Redis
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
_redis_client = self.init_async_client()
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
_duration = end_time - start_time
|
||||
|
@ -592,9 +611,11 @@ class RedisCache(BaseCache):
|
|||
await self.flush_cache_buffer() # logging done in here
|
||||
|
||||
async def async_increment(
|
||||
self, key, value: float, ttl: Optional[float] = None, **kwargs
|
||||
self, key, value: float, ttl: Optional[int] = None, **kwargs
|
||||
) -> float:
|
||||
_redis_client = self.init_async_client()
|
||||
from redis.asyncio import Redis
|
||||
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
start_time = time.time()
|
||||
try:
|
||||
async with _redis_client as redis_client:
|
||||
|
@ -708,7 +729,9 @@ class RedisCache(BaseCache):
|
|||
return key_value_dict
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
_redis_client = self.init_async_client()
|
||||
from redis.asyncio import Redis
|
||||
|
||||
_redis_client: Redis = self.init_async_client() # type: ignore
|
||||
key = self.check_and_fix_namespace(key=key)
|
||||
start_time = time.time()
|
||||
async with _redis_client as redis_client:
|
||||
|
@ -903,6 +926,12 @@ class RedisCache(BaseCache):
|
|||
async def disconnect(self):
    """
    Disconnect the async Redis connection pool stored on this cache.

    The pool's `disconnect` coroutine is awaited with `inuse_connections=True`.
    """
    pool = self.async_redis_conn_pool
    await pool.disconnect(inuse_connections=True)
|
||||
|
||||
async def async_delete_cache(self, key: str):
    """
    Delete a single key from Redis through the async client.

    Args:
        key: the Redis key to delete (one string, not a list of keys).

    NOTE(review): the key is sent to Redis exactly as given; it is not run
    through `check_and_fix_namespace` here - confirm callers pass the final key.
    """
    client = self.init_async_client()
    # The client is used as an async context manager for the duration of the call.
    async with client as conn:
        await conn.delete(key)
|
||||
|
||||
def delete_cache(self, key):
|
||||
self.redis_client.delete(key)
|
||||
|
||||
|
@ -1241,6 +1270,7 @@ class QdrantSemanticCache(BaseCache):
|
|||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
if collection_name is None:
|
||||
raise Exception("collection_name must be provided, passed None")
|
||||
|
@ -1261,12 +1291,12 @@ class QdrantSemanticCache(BaseCache):
|
|||
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_base = litellm.get_secret(qdrant_api_base)
|
||||
qdrant_api_base = get_secret_str(qdrant_api_base)
|
||||
if qdrant_api_key:
|
||||
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_key = litellm.get_secret(qdrant_api_key)
|
||||
qdrant_api_key = get_secret_str(qdrant_api_key)
|
||||
|
||||
qdrant_api_base = (
|
||||
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
|
||||
|
@ -1633,7 +1663,7 @@ class S3Cache(BaseCache):
|
|||
s3_bucket_name,
|
||||
s3_region_name=None,
|
||||
s3_api_version=None,
|
||||
s3_use_ssl=True,
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify=None,
|
||||
s3_endpoint_url=None,
|
||||
s3_aws_access_key_id=None,
|
||||
|
@ -1721,7 +1751,7 @@ class S3Cache(BaseCache):
|
|||
Bucket=self.bucket_name, Key=key
|
||||
)
|
||||
|
||||
if cached_response != None:
|
||||
if cached_response is not None:
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = (
|
||||
cached_response["Body"].read().decode("utf-8")
|
||||
|
@ -1739,7 +1769,7 @@ class S3Cache(BaseCache):
|
|||
)
|
||||
|
||||
return cached_response
|
||||
except botocore.exceptions.ClientError as e:
|
||||
except botocore.exceptions.ClientError as e: # type: ignore
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
verbose_logger.debug(
|
||||
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
|
||||
|
@ -2081,6 +2111,15 @@ class DualCache(BaseCache):
|
|||
if self.redis_cache is not None:
|
||||
self.redis_cache.delete_cache(key)
|
||||
|
||||
async def async_delete_cache(self, key: str):
    """
    Remove `key` from every configured cache layer.

    The in-memory layer is cleared first (synchronously), then the Redis
    layer (awaited). A layer that is None is skipped.
    """
    memory_layer = self.in_memory_cache
    if memory_layer is not None:
        memory_layer.delete_cache(key)

    redis_layer = self.redis_cache
    if redis_layer is not None:
        await redis_layer.async_delete_cache(key)
|
||||
|
||||
|
||||
#### LiteLLM.Completion / Embedding Cache ####
|
||||
class Cache:
|
||||
|
@ -2137,7 +2176,7 @@ class Cache:
|
|||
s3_path: Optional[str] = None,
|
||||
redis_semantic_cache_use_async=False,
|
||||
redis_semantic_cache_embedding_model="text-embedding-ada-002",
|
||||
redis_flush_size=None,
|
||||
redis_flush_size: Optional[int] = None,
|
||||
redis_startup_nodes: Optional[List] = None,
|
||||
disk_cache_dir=None,
|
||||
qdrant_api_base: Optional[str] = None,
|
||||
|
@ -2501,10 +2540,9 @@ class Cache:
|
|||
if self.ttl is not None:
|
||||
kwargs["ttl"] = self.ttl
|
||||
## Get Cache-Controls ##
|
||||
if kwargs.get("cache", None) is not None and isinstance(
|
||||
kwargs.get("cache"), dict
|
||||
):
|
||||
for k, v in kwargs.get("cache").items():
|
||||
_cache_kwargs = kwargs.get("cache", None)
|
||||
if isinstance(_cache_kwargs, dict):
|
||||
for k, v in _cache_kwargs.items():
|
||||
if k == "ttl":
|
||||
kwargs["ttl"] = v
|
||||
|
||||
|
@ -2574,14 +2612,15 @@ class Cache:
|
|||
**kwargs,
|
||||
)
|
||||
cache_list.append((cache_key, cached_data))
|
||||
if hasattr(self.cache, "async_set_cache_pipeline"):
|
||||
await self.cache.async_set_cache_pipeline(cache_list=cache_list)
|
||||
async_set_cache_pipeline = getattr(
|
||||
self.cache, "async_set_cache_pipeline", None
|
||||
)
|
||||
if async_set_cache_pipeline:
|
||||
await async_set_cache_pipeline(cache_list=cache_list)
|
||||
else:
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(
|
||||
self.cache.async_set_cache(cache_key, cached_data, **kwargs)
|
||||
)
|
||||
tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
@ -2611,13 +2650,15 @@ class Cache:
|
|||
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
|
||||
|
||||
async def ping(self):
    """
    Ping the underlying cache backend when it supports pinging.

    Returns:
        The result of awaiting the backend's `ping()`, or None when the
        backend exposes no usable `ping` attribute.
    """
    # A default of None keeps the old `hasattr` behaviour: a backend without
    # `ping` falls through to `return None` instead of raising AttributeError.
    cache_ping = getattr(self.cache, "ping", None)
    if cache_ping:
        return await cache_ping()
    return None
|
||||
|
||||
async def delete_cache_keys(self, keys):
    """
    Delete `keys` through the underlying cache backend when it supports it.

    Args:
        keys: passed unchanged to the backend's `delete_cache_keys`.

    Returns:
        The result of awaiting the backend's `delete_cache_keys(keys)`, or
        None when the backend exposes no usable `delete_cache_keys` attribute.
    """
    # A default of None keeps the old `hasattr` behaviour: a backend without
    # this method falls through to `return None` instead of raising
    # AttributeError.
    cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
    if cache_delete_cache_keys:
        return await cache_delete_cache_keys(keys)
    return None
|
||||
|
||||
async def disconnect(self):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue