Merge pull request #1809 from BerriAI/litellm_embedding_caching_updates

Support caching individual items in an embedding list (async embedding only).

Commit 28df60b609: 13 changed files with 638 additions and 196 deletions.
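A minimal usage sketch of what this change enables, based on the test added in this PR (the Azure deployment name and inputs are illustrative, and valid Azure OpenAI credentials are assumed to be set in the environment):

```python
import asyncio

import litellm
from litellm import aembedding
from litellm.caching import Cache


async def main():
    litellm.cache = Cache()  # in-memory by default; Redis and S3 are also supported

    # First call: every input is embedded, and each item is cached under its own key.
    await aembedding(
        model="azure/azure-embedding-model",  # illustrative deployment name
        input=["hey how's it going", "I'm doing well"],
        caching=True,
    )

    # Second call: the repeated input is served from cache; only the new one hits the API.
    val_2 = await aembedding(
        model="azure/azure-embedding-model",
        input=["hey how's it going", "I'm fine"],
        caching=True,
    )
    print(val_2._hidden_params["cache_hit"])  # True when any item came from the cache


asyncio.run(main())
```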
@@ -285,6 +285,7 @@ openai_compatible_endpoints: List = [
     "api.endpoints.anyscale.com/v1",
     "api.deepinfra.com/v1/openai",
     "api.mistral.ai/v1",
+    "api.together.xyz/v1",
 ]

 # this is maintained for Exception Mapping

@@ -294,6 +295,7 @@ openai_compatible_providers: List = [
    "deepinfra",
    "perplexity",
    "xinference",
+   "together_ai",
 ]

@@ -11,6 +11,7 @@
 import os
 import inspect
 import redis, litellm
+import redis.asyncio as async_redis
 from typing import List, Optional


@@ -67,7 +68,10 @@ def get_redis_url_from_environment():
     )


-def get_redis_client(**env_overrides):
+def _get_redis_client_logic(**env_overrides):
+    """
+    Common functionality across sync + async redis client implementations
+    """
     ### check if "os.environ/<key-name>" passed in
     for k, v in env_overrides.items():
         if isinstance(v, str) and v.startswith("os.environ/"):

@@ -85,9 +89,33 @@ def get_redis_client(**env_overrides):
         redis_kwargs.pop("port", None)
         redis_kwargs.pop("db", None)
         redis_kwargs.pop("password", None)
-
-        return redis.Redis.from_url(**redis_kwargs)
     elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
         raise ValueError("Either 'host' or 'url' must be specified for redis.")
     litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
+    return redis_kwargs
+
+
+def get_redis_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return redis.Redis.from_url(**redis_kwargs)
     return redis.Redis(**redis_kwargs)
+
+
+def get_redis_async_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.Redis.from_url(**redis_kwargs)
+    return async_redis.Redis(
+        socket_timeout=5,
+        **redis_kwargs,
+    )
+
+
+def get_redis_connection_pool(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.BlockingConnectionPool.from_url(
+            timeout=5, url=redis_kwargs["url"]
+        )
+    return async_redis.BlockingConnectionPool(timeout=5, **redis_kwargs)
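A sketch of how the refactored Redis helpers fit together: `_get_redis_client_logic` resolves the kwargs once, and the three public helpers build a sync client, an async client, or an async connection pool from them (host/port/password values below are placeholders):

```python
from litellm._redis import (
    get_redis_client,
    get_redis_async_client,
    get_redis_connection_pool,
)

redis_kwargs = {"host": "localhost", "port": 6379, "password": "hunter2"}  # placeholders

sync_client = get_redis_client(**redis_kwargs)          # redis.Redis
conn_pool = get_redis_connection_pool(**redis_kwargs)   # redis.asyncio.BlockingConnectionPool

# Mirrors RedisCache.init_async_client(): reuse one pool across async clients.
async_client = get_redis_async_client(connection_pool=conn_pool, **redis_kwargs)
```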
@@ -8,7 +8,7 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan

 import litellm
-import time, logging
+import time, logging, asyncio
 import json, traceback, ast, hashlib
 from typing import Optional, Literal, List, Union, Any
 from openai._models import BaseModel as OpenAIObject

@@ -28,9 +28,18 @@ class BaseCache:
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError

+    async def async_set_cache(self, key, value, **kwargs):
+        raise NotImplementedError
+
     def get_cache(self, key, **kwargs):
         raise NotImplementedError

+    async def async_get_cache(self, key, **kwargs):
+        raise NotImplementedError
+
+    async def disconnect(self):
+        raise NotImplementedError
+

 class InMemoryCache(BaseCache):
     def __init__(self):

@@ -43,6 +52,16 @@ class InMemoryCache(BaseCache):
         if "ttl" in kwargs:
             self.ttl_dict[key] = time.time() + kwargs["ttl"]

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        for cache_key, cache_value in cache_list:
+            if ttl is not None:
+                self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
+            else:
+                self.set_cache(key=cache_key, value=cache_value)
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:

@@ -57,21 +76,27 @@ class InMemoryCache(BaseCache):
             return cached_response
         return None

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()

+    async def disconnect(self):
+        pass
+
     def delete_cache(self, key):
         self.cache_dict.pop(key, None)
         self.ttl_dict.pop(key, None)


 class RedisCache(BaseCache):
-    def __init__(self, host=None, port=None, password=None, **kwargs):
-        import redis
-
-        # if users don't provider one, use the default litellm cache
-        from ._redis import get_redis_client
+    # if users don't provider one, use the default litellm cache
+    def __init__(self, host=None, port=None, password=None, **kwargs):
+        from ._redis import get_redis_client, get_redis_connection_pool

         redis_kwargs = {}
         if host is not None:
@@ -82,18 +107,84 @@ class RedisCache(BaseCache):
             redis_kwargs["password"] = password

         redis_kwargs.update(kwargs)

         self.redis_client = get_redis_client(**redis_kwargs)
+        self.redis_kwargs = redis_kwargs
+        self.async_redis_conn_pool = get_redis_connection_pool()
+
+    def init_async_client(self):
+        from ._redis import get_redis_async_client
+
+        return get_redis_async_client(
+            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
+        )

     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
-        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}")
+        print_verbose(f"Set Redis Cache: key: {key}\nValue {value}\nttl={ttl}")
         try:
             self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)

+    async def async_set_cache(self, key, value, **kwargs):
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
+            ttl = kwargs.get("ttl", None)
+            print_verbose(
+                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+            )
+            try:
+                await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+
+    async def async_set_cache_pipeline(self, cache_list, ttl=None):
+        """
+        Use Redis Pipelines for bulk write operations
+        """
+        _redis_client = self.init_async_client()
+        try:
+            async with _redis_client as redis_client:
+                async with redis_client.pipeline(transaction=True) as pipe:
+                    # Iterate through each key-value pair in the cache_list and set them in the pipeline.
+                    for cache_key, cache_value in cache_list:
+                        print_verbose(
+                            f"Set ASYNC Redis Cache PIPELINE: key: {cache_key}\nValue {cache_value}\nttl={ttl}"
+                        )
+                        # Set the value with a TTL if it's provided.
+                        if ttl is not None:
+                            pipe.setex(cache_key, ttl, json.dumps(cache_value))
+                        else:
+                            pipe.set(cache_key, json.dumps(cache_value))
+                    # Execute the pipeline and return the results.
+                    results = await pipe.execute()

+            print_verbose(f"pipeline results: {results}")
+            # Optionally, you could process 'results' to make sure that all set operations were successful.
+            return results
+        except Exception as e:
+            print_verbose(f"Error occurred in pipeline write - {str(e)}")
+            # NON blocking - notify users Redis is throwing an exception
+            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+
+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+        # cached_response is in `b{} convert it to ModelResponse
+        cached_response = cached_response.decode("utf-8")  # Convert bytes to string
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
     def get_cache(self, key, **kwargs):
         try:
             print_verbose(f"Get Redis Cache: key: {key}")
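The new async methods are context-manager based: `init_async_client()` returns a client bound to the shared connection pool, and bulk writes go through a single Redis pipeline. A minimal sketch of driving them directly, assuming a reachable Redis at the placeholder host and illustrative keys/values:

```python
import asyncio

from litellm.caching import RedisCache


async def demo():
    cache = RedisCache(host="localhost", port=6379, password="hunter2")  # placeholders

    # single async write / read
    await cache.async_set_cache("greeting", {"text": "hello"}, ttl=60)
    print(await cache.async_get_cache("greeting"))

    # bulk write via one Redis pipeline (the path used for embedding lists)
    await cache.async_set_cache_pipeline(
        cache_list=[("item-0", {"n": 0}), ("item-1", {"n": 1})], ttl=60
    )


asyncio.run(demo())
```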
@@ -101,30 +192,40 @@ class RedisCache(BaseCache):
             print_verbose(
                 f"Got Redis Cache: key: {key}, cached_response {cached_response}"
             )
-            if cached_response != None:
-                # cached_response is in `b{} convert it to ModelResponse
-                cached_response = cached_response.decode(
-                    "utf-8"
-                )  # Convert bytes to string
-                try:
-                    cached_response = json.loads(
-                        cached_response
-                    )  # Convert string to dictionary
-                except:
-                    cached_response = ast.literal_eval(cached_response)
-                return cached_response
+            return self._get_cache_logic(cached_response=cached_response)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)

+    async def async_get_cache(self, key, **kwargs):
+        _redis_client = self.init_async_client()
+        async with _redis_client as redis_client:
+            try:
+                print_verbose(f"Get Redis Cache: key: {key}")
+                cached_response = await redis_client.get(key)
+                print_verbose(
+                    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
+                )
+                response = self._get_cache_logic(cached_response=cached_response)
+                return response
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                traceback.print_exc()
+                logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
+
     def flush_cache(self):
         self.redis_client.flushall()

+    async def disconnect(self):
+        pass
+
     def delete_cache(self, key):
         self.redis_client.delete(key)


 class S3Cache(BaseCache):
     def __init__(
         self,

@@ -202,6 +303,9 @@ class S3Cache(BaseCache):
             # NON blocking - notify users S3 is throwing an exception
             print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
     def get_cache(self, key, **kwargs):
         import boto3, botocore

@@ -244,6 +348,9 @@ class S3Cache(BaseCache):
             traceback.print_exc()
             print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")

+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         pass

@@ -361,9 +468,9 @@ class Cache:
         """
         if type == "redis":
             self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
-        if type == "local":
+        elif type == "local":
             self.cache = InMemoryCache()
-        if type == "s3":
+        elif type == "s3":
             self.cache = S3Cache(
                 s3_bucket_name=s3_bucket_name,
                 s3_region_name=s3_region_name,

@@ -489,6 +596,45 @@ class Cache:
                 }
                 time.sleep(0.02)

+    def _get_cache_logic(
+        self,
+        cached_result: Optional[Any],
+        max_age: Optional[float],
+    ):
+        """
+        Common get cache logic across sync + async implementations
+        """
+        # Check if a timestamp was stored with the cached response
+        if (
+            cached_result is not None
+            and isinstance(cached_result, dict)
+            and "timestamp" in cached_result
+        ):
+            timestamp = cached_result["timestamp"]
+            current_time = time.time()
+
+            # Calculate age of the cached response
+            response_age = current_time - timestamp
+
+            # Check if the cached response is older than the max-age
+            if max_age is not None and response_age > max_age:
+                return None  # Cached response is too old
+
+            # If the response is fresh, or there's no max-age requirement, return the cached response
+            # cached_response is in `b{} convert it to ModelResponse
+            cached_response = cached_result.get("response")
+            try:
+                if isinstance(cached_response, dict):
+                    pass
+                else:
+                    cached_response = json.loads(
+                        cached_response  # type: ignore
+                    )  # Convert string to dictionary
+            except:
+                cached_response = ast.literal_eval(cached_response)  # type: ignore
+            return cached_response
+        return cached_result
+
     def get_cache(self, *args, **kwargs):
         """
         Retrieves the cached result for the given arguments.

@@ -511,54 +657,40 @@ class Cache:
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
                 cached_result = self.cache.get_cache(cache_key)
-                # Check if a timestamp was stored with the cached response
-                if (
-                    cached_result is not None
-                    and isinstance(cached_result, dict)
-                    and "timestamp" in cached_result
-                    and max_age is not None
-                ):
-                    timestamp = cached_result["timestamp"]
-                    current_time = time.time()
-
-                    # Calculate age of the cached response
-                    response_age = current_time - timestamp
-
-                    # Check if the cached response is older than the max-age
-                    if response_age > max_age:
-                        print_verbose(
-                            f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
-                        )
-                        return None  # Cached response is too old
-
-                    # If the response is fresh, or there's no max-age requirement, return the cached response
-                    # cached_response is in `b{} convert it to ModelResponse
-                    cached_response = cached_result.get("response")
-                    try:
-                        if isinstance(cached_response, dict):
-                            pass
-                        else:
-                            cached_response = json.loads(
-                                cached_response
-                            )  # Convert string to dictionary
-                    except:
-                        cached_response = ast.literal_eval(cached_response)
-                    return cached_response
-                return cached_result
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
         except Exception as e:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None

-    def add_cache(self, result, *args, **kwargs):
+    async def async_get_cache(self, *args, **kwargs):
         """
-        Adds a result to the cache.
+        Async get cache implementation.

-        Args:
-            *args: args to litellm.completion() or embedding()
-            **kwargs: kwargs to litellm.completion() or embedding()
+        Used for embedding calls in async wrapper
+        """
+        try:  # never block execution
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
+                cached_result = await self.cache.async_get_cache(cache_key)
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
+        except Exception as e:
+            print_verbose(f"An exception occurred: {traceback.format_exc()}")
+            return None

-        Returns:
-            None
+    def _add_cache_logic(self, result, *args, **kwargs):
+        """
+        Common implementation across sync + async add_cache functions
         """
         try:
             if "cache_key" in kwargs:

@@ -577,14 +709,82 @@ class Cache:
                     if k == "ttl":
                         kwargs["ttl"] = v
                 cached_data = {"timestamp": time.time(), "response": result}
-                self.cache.set_cache(cache_key, cached_data, **kwargs)
+                return cache_key, cached_data, kwargs
+            else:
+                raise Exception("cache key is None")
+        except Exception as e:
+            raise e
+
+    def add_cache(self, result, *args, **kwargs):
+        """
+        Adds a result to the cache.
+
+        Args:
+            *args: args to litellm.completion() or embedding()
+            **kwargs: kwargs to litellm.completion() or embedding()
+
+        Returns:
+            None
+        """
+        try:
+            cache_key, cached_data, kwargs = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            self.cache.set_cache(cache_key, cached_data, **kwargs)
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
             pass

-    async def _async_add_cache(self, result, *args, **kwargs):
-        self.add_cache(result, *args, **kwargs)
+    async def async_add_cache(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache
+        """
+        try:
+            cache_key, cached_data, kwargs = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()
+
+    async def async_add_cache_pipeline(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache for Embedding calls
+
+        Does a bulk write, to prevent using too many clients
+        """
+        try:
+            cache_list = []
+            for idx, i in enumerate(kwargs["input"]):
+                preset_cache_key = litellm.cache.get_cache_key(
+                    *args, **{**kwargs, "input": i}
+                )
+                kwargs["cache_key"] = preset_cache_key
+                embedding_response = result.data[idx]
+                cache_key, cached_data, kwargs = self._add_cache_logic(
+                    result=embedding_response,
+                    *args,
+                    **kwargs,
+                )
+                cache_list.append((cache_key, cached_data))
+            if hasattr(self.cache, "async_set_cache_pipeline"):
+                await self.cache.async_set_cache_pipeline(cache_list=cache_list)
+            else:
+                tasks = []
+                for val in cache_list:
+                    tasks.append(
+                        self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+                    )
+                await asyncio.gather(*tasks)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()
+
+    async def disconnect(self):
+        if hasattr(self.cache, "disconnect"):
+            await self.cache.disconnect()
+
     def enable_cache(
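At the `Cache` wrapper level, the sync and async paths now share `_get_cache_logic` / `_add_cache_logic`, and the async variants delegate to the underlying backend's async methods. A small sketch against the local in-memory backend (the key and result values are illustrative):

```python
import asyncio

from litellm.caching import Cache


async def demo():
    cache = Cache(type="local")

    # write synchronously, then read it back through the new async path
    cache.add_cache(result={"id": "chatcmpl-123"}, cache_key="example-key")
    print(await cache.async_get_cache(cache_key="example-key"))

    # close any underlying connections (a no-op for the in-memory backend)
    await cache.disconnect()


asyncio.run(demo())
```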
@@ -440,8 +440,8 @@ class OpenAIChatCompletion(BaseLLM):
                     input=data["messages"],
                     api_key=api_key,
                     additional_args={
-                        "headers": headers,
-                        "api_base": api_base,
+                        "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
+                        "api_base": openai_client._base_url._uri_reference,
                         "acompletion": False,
                         "complete_input_dict": data,
                     },

@@ -1,3 +1,7 @@
+"""
+Deprecated. We now do together ai calls via the openai client.
+Reference: https://docs.together.ai/docs/openai-api-compatibility
+"""
 import os, types
 import json
 from enum import Enum
@@ -10,6 +10,7 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any, Literal, Union
 from functools import partial
+
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx

@@ -234,6 +235,9 @@ async def acompletion(
         "model_list": model_list,
         "acompletion": True,  # assuming this is a required parameter
     }
+    _, custom_llm_provider, _, _ = get_llm_provider(
+        model=model, api_base=completion_kwargs.get("base_url", None)
+    )
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, **completion_kwargs, **kwargs)

@@ -245,7 +249,6 @@ async def acompletion(
         _, custom_llm_provider, _, _ = get_llm_provider(
             model=model, api_base=kwargs.get("api_base", None)
         )
-
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "azure"

@@ -788,6 +791,7 @@ def completion(
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
             or custom_llm_provider == "openai"
+            or custom_llm_provider == "together_ai"
             or "ft:gpt-3.5-turbo" in model  # finetune gpt-3.5-turbo
         ):  # allow user to make an openai call with a custom base
             # note: if a user sets a custom base - we should ensure this works

@@ -1327,6 +1331,9 @@ def completion(
             or ("togethercomputer" in model)
             or (model in litellm.together_ai_models)
         ):
+            """
+            Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility
+            """
             custom_llm_provider = "together_ai"
             together_ai_key = (
                 api_key
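With these routing changes, Together AI traffic goes through the OpenAI-compatible client against api.together.xyz/v1, so a plain completion call is all that is needed. A sketch, where the model name is illustrative and a TOGETHER_API_KEY (or one of the other supported key variables) is assumed to be set:

```python
import litellm

response = litellm.completion(
    model="together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative model name
    messages=[{"role": "user", "content": "hello"}],
)
print(response.choices[0].message.content)
```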
@@ -380,7 +380,7 @@ def run_server(
             import gunicorn.app.base
         except:
             raise ImportError(
-                "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`"
+                "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
             )

         if config is not None:

@@ -444,6 +444,7 @@ def run_server(
         )
         if port == 8000 and is_port_in_use(port):
             port = random.randint(1024, 49152)
+
         from litellm.proxy.proxy_server import app

         if run_gunicorn == False:

@@ -521,5 +522,6 @@ def run_server(
         ).run()  # Run gunicorn


+
 if __name__ == "__main__":
     run_server()

@@ -7,6 +7,20 @@ import secrets, subprocess
 import hashlib, uuid
 import warnings
 import importlib
+import warnings
+
+
+def showwarning(message, category, filename, lineno, file=None, line=None):
+    traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
+    if file is not None:
+        file.write(traceback_info)
+
+
+warnings.showwarning = showwarning
+warnings.filterwarnings("default", category=UserWarning)
+
+# Your client code here
+
 messages: list = []
 sys.path.insert(

@@ -4053,9 +4067,12 @@ def _has_user_setup_sso():
 async def shutdown_event():
     global prisma_client, master_key, user_custom_auth, user_custom_key_generate
     if prisma_client:
-
         verbose_proxy_logger.debug("Disconnecting from Prisma")
         await prisma_client.disconnect()
+
+    if litellm.cache is not None:
+        await litellm.cache.disconnect()
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()

@@ -21,10 +21,18 @@ def setup_and_teardown():
     import litellm

     importlib.reload(litellm)
+    import asyncio
+
+    loop = asyncio.get_event_loop_policy().new_event_loop()
+    asyncio.set_event_loop(loop)
     print(litellm)
     # from litellm import Router, completion, aembedding, acompletion, embedding
     yield
+
+    # Teardown code (executes after the yield point)
+    loop.close()  # Close the loop created earlier
+    asyncio.set_event_loop(None)  # Remove the reference to the loop


 def pytest_collection_modifyitems(config, items):
     # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
@@ -11,10 +11,10 @@ sys.path.insert(
 )  # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm import embedding, completion
+from litellm import embedding, completion, aembedding
 from litellm.caching import Cache
 import random
-import hashlib
+import hashlib, asyncio

 # litellm.set_verbose=True

@@ -106,10 +106,7 @@ def test_caching_with_cache_controls():
         )
         print(f"response1: {response1}")
         print(f"response2: {response2}")
-        assert (
-            response2["choices"][0]["message"]["content"]
-            == response1["choices"][0]["message"]["content"]
-        )
+        assert response2["id"] == response1["id"]
     except Exception as e:
         print(f"error occurred: {traceback.format_exc()}")
         pytest.fail(f"Error occurred: {e}")

@@ -259,6 +256,84 @@ def test_embedding_caching_azure():
 # test_embedding_caching_azure()


+@pytest.mark.asyncio
+async def test_embedding_caching_azure_individual_items():
+    """
+    Tests caching for individual items in an embedding list
+
+    - Cache an item
+    - call aembedding(..) with the item + 1 unique item
+    - compare to a 2nd aembedding (...) with 2 unique items
+    ```
+    embedding_1 = ["hey how's it going", "I'm doing well"]
+    embedding_val_1 = embedding(...)
+
+    embedding_2 = ["hey how's it going", "I'm fine"]
+    embedding_val_2 = embedding(...)
+
+    assert embedding_val_1[0]["id"] == embedding_val_2[0]["id"]
+    ```
+    """
+    litellm.cache = Cache()
+    common_msg = f"hey how's it going {uuid.uuid4()}"
+    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [common_msg]
+    embedding_2 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+    ]
+
+    embedding_val_1 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_1, caching=True
+    )
+    embedding_val_2 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_2, caching=True
+    )
+    print(f"embedding_val_2._hidden_params: {embedding_val_2._hidden_params}")
+    assert embedding_val_2._hidden_params["cache_hit"] == True
+
+
+@pytest.mark.asyncio
+async def test_redis_cache_basic():
+    """
+    Init redis client
+    - write to client
+    - read from client
+    """
+    litellm.set_verbose = False
+
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
+    ]
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+
+    cache_key = litellm.cache.get_cache_key(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"cache_key: {cache_key}")
+    litellm.cache.add_cache(result=response1, cache_key=cache_key)
+    print(f"cache key pre async get: {cache_key}")
+    stored_val = await litellm.cache.async_get_cache(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"stored_val: {stored_val}")
+    assert stored_val["id"] == response1.id
+
+
 def test_redis_cache_completion():
     litellm.set_verbose = False

@@ -406,7 +481,7 @@ def test_redis_cache_acompletion_stream():
     import asyncio

     try:
-        litellm.set_verbose = True
+        litellm.set_verbose = False
         random_word = generate_random_word()
         messages = [
             {

@@ -434,7 +509,6 @@ def test_redis_cache_acompletion_stream():
                 stream=True,
             )
             async for chunk in response1:
-                print(chunk)
                 response_1_content += chunk.choices[0].delta.content or ""
             print(response_1_content)

@@ -452,7 +526,6 @@ def test_redis_cache_acompletion_stream():
                 stream=True,
             )
             async for chunk in response2:
-                print(chunk)
                 response_2_content += chunk.choices[0].delta.content or ""
             print(response_2_content)

@@ -914,101 +987,3 @@ def test_cache_context_managers():


 # test_cache_context_managers()
-
-# test_custom_redis_cache_params()
-
-
-# def test_redis_cache_with_ttl():
-#     cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
-#     sample_model_response_object_str = """{
-#     "choices": [
-#         {
-#         "finish_reason": "stop",
-#         "index": 0,
-#         "message": {
-#             "role": "assistant",
-#             "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-#         }
-#         }
-#     ],
-#     "created": 1691429984.3852863,
-#     "model": "claude-instant-1",
-#     "usage": {
-#         "prompt_tokens": 18,
-#         "completion_tokens": 23,
-#         "total_tokens": 41
-#     }
-#     }"""
-#     sample_model_response_object = {
-#     "choices": [
-#         {
-#         "finish_reason": "stop",
-#         "index": 0,
-#         "message": {
-#             "role": "assistant",
-#             "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-#         }
-#         }
-#     ],
-#     "created": 1691429984.3852863,
-#     "model": "claude-instant-1",
-#     "usage": {
-#         "prompt_tokens": 18,
-#         "completion_tokens": 23,
-#         "total_tokens": 41
-#     }
-#     }
-#     cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
-#     cached_value = cache.get_cache(cache_key="test_key")
-#     print(f"cached-value: {cached_value}")
-#     assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
-#     time.sleep(2)
-#     assert cache.get_cache(cache_key="test_key") is None
-
-# # test_redis_cache_with_ttl()
-
-
-# def test_in_memory_cache_with_ttl():
-#     cache = Cache(type="local")
-#     sample_model_response_object_str = """{
-#     "choices": [
-#         {
-#         "finish_reason": "stop",
-#         "index": 0,
-#         "message": {
-#             "role": "assistant",
-#             "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-#         }
-#         }
-#     ],
-#     "created": 1691429984.3852863,
-#     "model": "claude-instant-1",
-#     "usage": {
-#         "prompt_tokens": 18,
-#         "completion_tokens": 23,
-#         "total_tokens": 41
-#     }
-#     }"""
-#     sample_model_response_object = {
-#     "choices": [
-#         {
-#         "finish_reason": "stop",
-#         "index": 0,
-#         "message": {
-#             "role": "assistant",
-#             "content": "I'm doing well, thank you for asking. I am Claude, an AI assistant created by Anthropic."
-#         }
-#         }
-#     ],
-#     "created": 1691429984.3852863,
-#     "model": "claude-instant-1",
-#     "usage": {
-#         "prompt_tokens": 18,
-#         "completion_tokens": 23,
-#         "total_tokens": 41
-#     }
-#     }
-#     cache.add_cache(cache_key="test_key", result=sample_model_response_object_str, ttl=1)
-#     cached_value = cache.get_cache(cache_key="test_key")
-#     assert cached_value['choices'][0]['message']['content'] == sample_model_response_object['choices'][0]['message']['content']
-#     time.sleep(2)
-#     assert cache.get_cache(cache_key="test_key") is None
-# # test_in_memory_cache_with_ttl()

@@ -1994,6 +1994,7 @@ def test_completion_palm_stream():


 def test_completion_together_ai_stream():
+    litellm.set_verbose = True
     user_message = "Write 1pg about YC & litellm"
     messages = [{"content": user_message, "role": "user"}]
     try:

@@ -556,7 +556,6 @@ async def test_async_chat_bedrock_stream():

 # asyncio.run(test_async_chat_bedrock_stream())

-
 ## Test Sagemaker + Async
 @pytest.mark.asyncio
 async def test_async_chat_sagemaker_stream():

@@ -725,7 +724,7 @@ async def test_async_embedding_bedrock():
         response = await litellm.aembedding(
             model="bedrock/cohere.embed-multilingual-v3",
             input=["good morning from litellm"],
-            aws_region_name="os.environ/AWS_REGION_NAME_2",
+            aws_region_name="us-east-1",
         )
         await asyncio.sleep(1)
         print(f"customHandler_success.errors: {customHandler_success.errors}")

@@ -758,6 +757,7 @@ async def test_async_embedding_bedrock():
 ## Test Azure - completion, embedding
 @pytest.mark.asyncio
 async def test_async_completion_azure_caching():
+    litellm.set_verbose = True
     customHandler_caching = CompletionCustomHandler()
     litellm.cache = Cache(
         type="redis",

@@ -812,6 +812,7 @@ async def test_async_embedding_azure_caching():
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
     print(customHandler_caching.states)
+    print(customHandler_caching.errors)
     assert len(customHandler_caching.errors) == 0
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
litellm/utils.py (227 changed lines)
@@ -55,6 +55,7 @@ from .integrations.litedebugger import LiteDebugger
 from .proxy._types import KeyManagementSystem
 from openai import OpenAIError as OriginalError
 from openai._models import BaseModel as OpenAIObject
+from .caching import S3Cache
 from .exceptions import (
     AuthenticationError,
     BadRequestError,

@@ -862,6 +863,7 @@ class Logging:
                 curl_command += additional_args.get("request_str", None)
             elif api_base == "":
                 curl_command = self.model_call_details
+            print_verbose(f"\033[92m{curl_command}\033[0m\n")
             verbose_logger.info(f"\033[92m{curl_command}\033[0m\n")
             if self.logger_fn and callable(self.logger_fn):
                 try:

@@ -2196,12 +2198,21 @@ def client(original_function):
            )
            # if caching is false or cache["no-cache"]==True, don't run this
            if (
-               (kwargs.get("caching", None) is None and litellm.cache is not None)
-               or kwargs.get("caching", False) == True
-               or (
-                   kwargs.get("cache", None) is not None
-                   and kwargs.get("cache", {}).get("no-cache", False) != True
+               (
+                   (
+                       kwargs.get("caching", None) is None
+                       and kwargs.get("cache", None) is None
+                       and litellm.cache is not None
+                   )
+                   or kwargs.get("caching", False) == True
+                   or (
+                       kwargs.get("cache", None) is not None
+                       and kwargs.get("cache", {}).get("no-cache", False) != True
+                   )
                )
+               and kwargs.get("aembedding", False) != True
+               and kwargs.get("acompletion", False) != True
+               and kwargs.get("aimg_generation", False) != True
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose(f"INSIDE CHECKING CACHE")

@@ -2435,6 +2446,7 @@ def client(original_function):
        result = None
        logging_obj = kwargs.get("litellm_logging_obj", None)
        # only set litellm_call_id if its not in kwargs
+       call_type = original_function.__name__
        if "litellm_call_id" not in kwargs:
            kwargs["litellm_call_id"] = str(uuid.uuid4())
        try:

@@ -2465,8 +2477,14 @@ def client(original_function):
                f"kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}"
            )
            # if caching is false, don't run this
+           final_embedding_cached_response = None
+
            if (
-               (kwargs.get("caching", None) is None and litellm.cache is not None)
+               (
+                   kwargs.get("caching", None) is None
+                   and kwargs.get("cache", None) is None
+                   and litellm.cache is not None
+               )
                or kwargs.get("caching", False) == True
                or (
                    kwargs.get("cache", None) is not None

@@ -2481,8 +2499,36 @@ def client(original_function):
                    in litellm.cache.supported_call_types
                ):
                    print_verbose(f"Checking Cache")
-                   cached_result = litellm.cache.get_cache(*args, **kwargs)
-                   if cached_result != None:
+                   if call_type == CallTypes.aembedding.value and isinstance(
+                       kwargs["input"], list
+                   ):
+                       tasks = []
+                       for idx, i in enumerate(kwargs["input"]):
+                           preset_cache_key = litellm.cache.get_cache_key(
+                               *args, **{**kwargs, "input": i}
+                           )
+                           tasks.append(
+                               litellm.cache.async_get_cache(
+                                   cache_key=preset_cache_key
+                               )
+                           )
+                       cached_result = await asyncio.gather(*tasks)
+                       ## check if cached result is None ##
+                       if cached_result is not None and isinstance(
+                           cached_result, list
+                       ):
+                           if len(cached_result) == 1 and cached_result[0] is None:
+                               cached_result = None
+                   else:
+                       preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs)
+                       kwargs[
+                           "preset_cache_key"
+                       ] = preset_cache_key  # for streaming calls, we need to pass the preset_cache_key
+                       cached_result = litellm.cache.get_cache(*args, **kwargs)
+
+                   if cached_result is not None and not isinstance(
+                       cached_result, list
+                   ):
                        print_verbose(f"Cache Hit!")
                        call_type = original_function.__name__
                        if call_type == CallTypes.acompletion.value and isinstance(
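The per-item lookup above builds one cache key per element of `input` and fetches them all concurrently, with misses coming back as None. A reduced sketch of the same idea, assuming `litellm.cache` has already been configured and using an illustrative model name:

```python
import asyncio

import litellm


async def lookup_items(inputs):
    # one key per input item, derived the same way the wrapper does it
    keys = [
        litellm.cache.get_cache_key(model="azure/azure-embedding-model", input=i)
        for i in inputs
    ]
    # one async cache read per item, gathered in parallel
    return await asyncio.gather(
        *[litellm.cache.async_get_cache(cache_key=k) for k in keys]
    )
```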
@@ -2555,6 +2601,103 @@ def client(original_function):
                            args=(cached_result, start_time, end_time, cache_hit),
                        ).start()
                        return cached_result
+                   elif (
+                       call_type == CallTypes.aembedding.value
+                       and cached_result is not None
+                       and isinstance(cached_result, list)
+                       and litellm.cache is not None
+                       and not isinstance(
+                           litellm.cache.cache, S3Cache
+                       )  # s3 doesn't support bulk writing. Exclude.
+                   ):
+                       remaining_list = []
+                       non_null_list = []
+                       for idx, cr in enumerate(cached_result):
+                           if cr is None:
+                               remaining_list.append(kwargs["input"][idx])
+                           else:
+                               non_null_list.append((idx, cr))
+                       original_kwargs_input = kwargs["input"]
+                       kwargs["input"] = remaining_list
+                       if len(non_null_list) > 0:
+                           print_verbose(
+                               f"EMBEDDING CACHE HIT! - {len(non_null_list)}"
+                           )
+                           final_embedding_cached_response = EmbeddingResponse(
+                               model=kwargs.get("model"),
+                               data=[None] * len(original_kwargs_input),
+                           )
+                           final_embedding_cached_response._hidden_params[
+                               "cache_hit"
+                           ] = True
+
+                           for val in non_null_list:
+                               idx, cr = val  # (idx, cr) tuple
+                               if cr is not None:
+                                   final_embedding_cached_response.data[idx] = cr
+                       if len(remaining_list) == 0:
+                           # LOG SUCCESS
+                           cache_hit = True
+                           end_time = datetime.datetime.now()
+                           (
+                               model,
+                               custom_llm_provider,
+                               dynamic_api_key,
+                               api_base,
+                           ) = litellm.get_llm_provider(
+                               model=model,
+                               custom_llm_provider=kwargs.get(
+                                   "custom_llm_provider", None
+                               ),
+                               api_base=kwargs.get("api_base", None),
+                               api_key=kwargs.get("api_key", None),
+                           )
+                           print_verbose(
+                               f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
+                           )
+                           logging_obj.update_environment_variables(
+                               model=model,
+                               user=kwargs.get("user", None),
+                               optional_params={},
+                               litellm_params={
+                                   "logger_fn": kwargs.get("logger_fn", None),
+                                   "acompletion": True,
+                                   "metadata": kwargs.get("metadata", {}),
+                                   "model_info": kwargs.get("model_info", {}),
+                                   "proxy_server_request": kwargs.get(
+                                       "proxy_server_request", None
+                                   ),
+                                   "preset_cache_key": kwargs.get(
+                                       "preset_cache_key", None
+                                   ),
+                                   "stream_response": kwargs.get(
+                                       "stream_response", {}
+                                   ),
+                               },
+                               input=kwargs.get("messages", ""),
+                               api_key=kwargs.get("api_key", None),
+                               original_response=str(final_embedding_cached_response),
+                               additional_args=None,
+                               stream=kwargs.get("stream", False),
+                           )
+                           asyncio.create_task(
+                               logging_obj.async_success_handler(
+                                   final_embedding_cached_response,
+                                   start_time,
+                                   end_time,
+                                   cache_hit,
+                               )
+                           )
+                           threading.Thread(
+                               target=logging_obj.success_handler,
+                               args=(
+                                   final_embedding_cached_response,
+                                   start_time,
+                                   end_time,
+                                   cache_hit,
+                               ),
+                           ).start()
+                           return final_embedding_cached_response
            # MODEL CALL
            result = await original_function(*args, **kwargs)
            end_time = datetime.datetime.now()

@@ -2587,12 +2730,28 @@ def client(original_function):
                if isinstance(result, litellm.ModelResponse) or isinstance(
                    result, litellm.EmbeddingResponse
                ):
-                   asyncio.create_task(
-                       litellm.cache._async_add_cache(result.json(), *args, **kwargs)
-                   )
+                   if (
+                       isinstance(result, EmbeddingResponse)
+                       and isinstance(kwargs["input"], list)
+                       and litellm.cache is not None
+                       and not isinstance(
+                           litellm.cache.cache, S3Cache
+                       )  # s3 doesn't support bulk writing. Exclude.
+                   ):
+                       asyncio.create_task(
+                           litellm.cache.async_add_cache_pipeline(
+                               result, *args, **kwargs
+                           )
+                       )
+                   else:
+                       asyncio.create_task(
+                           litellm.cache.async_add_cache(
+                               result.json(), *args, **kwargs
+                           )
+                       )
                else:
                    asyncio.create_task(
-                       litellm.cache._async_add_cache(result, *args, **kwargs)
+                       litellm.cache.async_add_cache(result, *args, **kwargs)
                    )
                # LOG SUCCESS - handle streaming success logging in the _next_ object
                print_verbose(

@@ -2616,6 +2775,27 @@ def client(original_function):
                result._response_ms = (
                    end_time - start_time
                ).total_seconds() * 1000  # return response latency in ms like openai
+
+           if (
+               isinstance(result, EmbeddingResponse)
+               and final_embedding_cached_response is not None
+           ):
+               idx = 0
+               final_data_list = []
+               for item in final_embedding_cached_response.data:
+                   if item is None:
+                       final_data_list.append(result.data[idx])
+                       idx += 1
+                   else:
+                       final_data_list.append(item)
+
+               final_embedding_cached_response.data = final_data_list
+               final_embedding_cached_response._hidden_params["cache_hit"] = True
+               final_embedding_cached_response._response_ms = (
+                   end_time - start_time
+               ).total_seconds() * 1000
+               return final_embedding_cached_response
+
            return result
        except Exception as e:
            traceback_exception = traceback.format_exc()
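The merge step above keeps cached items at their original positions and fills the gaps, in order, with the freshly computed embeddings. The same idea in plain Python (the values are illustrative):

```python
cached = [None, {"embedding": [0.1]}, None]            # per-input cache lookups
fresh = [{"embedding": [0.2]}, {"embedding": [0.3]}]   # results for the two cache misses

merged, i = [], 0
for item in cached:
    if item is None:
        # a cache miss: take the next freshly computed result
        merged.append(fresh[i])
        i += 1
    else:
        # a cache hit: keep it in place
        merged.append(item)

assert merged == [{"embedding": [0.2]}, {"embedding": [0.1]}, {"embedding": [0.3]}]
```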
@@ -3275,7 +3455,11 @@ def completion_cost(
        else:
            raise Exception(f"Model={model} not found in completion cost model map")
        # Calculate cost based on prompt_tokens, completion_tokens
-       if "togethercomputer" in model or "together_ai" in model:
+       if (
+           "togethercomputer" in model
+           or "together_ai" in model
+           or custom_llm_provider == "together_ai"
+       ):
            # together ai prices based on size of llm
            # get_model_params_and_category takes a model name and returns the category of LLM size it is in model_prices_and_context_window.json
            model = get_model_params_and_category(model)

@@ -3864,7 +4048,7 @@ def get_optional_params(
        _check_valid_arg(supported_params=supported_params)

        if stream:
-           optional_params["stream_tokens"] = stream
+           optional_params["stream"] = stream
        if temperature is not None:
            optional_params["temperature"] = temperature
        if top_p is not None:

@@ -4498,6 +4682,14 @@ def get_llm_provider(
            # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
            api_base = "https://api.voyageai.com/v1"
            dynamic_api_key = get_secret("VOYAGE_API_KEY")
+       elif custom_llm_provider == "together_ai":
+           api_base = "https://api.together.xyz/v1"
+           dynamic_api_key = (
+               get_secret("TOGETHER_API_KEY")
+               or get_secret("TOGETHER_AI_API_KEY")
+               or get_secret("TOGETHERAI_API_KEY")
+               or get_secret("TOGETHER_AI_TOKEN")
+           )
        return model, custom_llm_provider, dynamic_api_key, api_base
    elif model.split("/", 1)[0] in litellm.provider_list:
        custom_llm_provider = model.split("/", 1)[0]

@@ -6383,7 +6575,12 @@ def exception_type(
                    message=f"BedrockException - {original_exception.message}",
                    llm_provider="bedrock",
                    model=model,
-                   response=original_exception.response,
+                   response=httpx.Response(
+                       status_code=500,
+                       request=httpx.Request(
+                           method="POST", url="https://api.openai.com/v1/"
+                       ),
+                   ),
                )
            elif original_exception.status_code == 401:
                exception_mapping_worked = True