forked from phoenix/litellm-mirror

fix: support async redis caching

parent 817a3d29b7
commit 007870390d

6 changed files with 357 additions and 122 deletions
@@ -11,6 +11,7 @@
 import os
 import inspect
 import redis, litellm
+import redis.asyncio as async_redis
 from typing import List, Optional
 
 
@@ -67,7 +68,10 @@ def get_redis_url_from_environment():
     )
 
 
-def get_redis_client(**env_overrides):
+def _get_redis_client_logic(**env_overrides):
+    """
+    Common functionality across sync + async redis client implementations
+    """
     ### check if "os.environ/<key-name>" passed in
     for k, v in env_overrides.items():
         if isinstance(v, str) and v.startswith("os.environ/"):
@@ -85,9 +89,21 @@ def get_redis_client(**env_overrides):
         redis_kwargs.pop("port", None)
         redis_kwargs.pop("db", None)
         redis_kwargs.pop("password", None)
-
-        return redis.Redis.from_url(**redis_kwargs)
     elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
         raise ValueError("Either 'host' or 'url' must be specified for redis.")
     litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
-    return redis.Redis(**redis_kwargs)
+    return redis_kwargs
+
+
+def get_redis_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return redis.Redis.from_url(**redis_kwargs)
+    return redis.Redis(**redis_kwargs)
+
+
+def get_redis_async_client(**env_overrides):
+    redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        return async_redis.Redis.from_url(**redis_kwargs)
+    return async_redis.Redis(socket_timeout=5, **redis_kwargs)
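Note: the hunks above appear to come from litellm's internal Redis helper (litellm/_redis.py in the upstream repo). They factor the kwargs handling into _get_redis_client_logic() so the sync redis.Redis client and the new redis.asyncio client are built from the same configuration. A minimal sketch of exercising the new async helper; the import path and connection details are assumptions, not part of the diff:

```python
# Sketch only: assumes get_redis_async_client is importable from litellm._redis
# and that a Redis server is reachable on localhost:6379.
import asyncio

from litellm._redis import get_redis_async_client


async def main():
    client = get_redis_async_client(host="localhost", port=6379)
    # redis.asyncio clients support async context management; the connection
    # is closed when the block exits (same pattern the cache code below uses).
    async with client as redis_client:
        await redis_client.set(name="litellm-demo-key", value="hello", ex=60)
        value = await redis_client.get("litellm-demo-key")
        print(value)  # b'hello'


asyncio.run(main())
```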
@@ -26,9 +26,18 @@ class BaseCache:
     def set_cache(self, key, value, **kwargs):
         raise NotImplementedError
 
+    async def async_set_cache(self, key, value, **kwargs):
+        raise NotImplementedError
+
     def get_cache(self, key, **kwargs):
         raise NotImplementedError
 
+    async def async_get_cache(self, key, **kwargs):
+        raise NotImplementedError
+
+    async def disconnect(self):
+        raise NotImplementedError
+
 
 class InMemoryCache(BaseCache):
     def __init__(self):
@@ -41,6 +50,9 @@ class InMemoryCache(BaseCache):
         if "ttl" in kwargs:
             self.ttl_dict[key] = time.time() + kwargs["ttl"]
 
+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
     def get_cache(self, key, **kwargs):
         if key in self.cache_dict:
             if key in self.ttl_dict:
@@ -55,16 +67,21 @@ class InMemoryCache(BaseCache):
             return cached_response
         return None
 
+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         self.cache_dict.clear()
         self.ttl_dict.clear()
 
+    async def disconnect(self):
+        pass
+
 
 class RedisCache(BaseCache):
-    def __init__(self, host=None, port=None, password=None, **kwargs):
-        import redis
-
-        # if users don't provider one, use the default litellm cache
+    # if users don't provider one, use the default litellm cache
+    def __init__(self, host=None, port=None, password=None, **kwargs):
         from ._redis import get_redis_client
 
         redis_kwargs = {}
@@ -76,8 +93,13 @@ class RedisCache(BaseCache):
             redis_kwargs["password"] = password
 
         redis_kwargs.update(kwargs)
 
         self.redis_client = get_redis_client(**redis_kwargs)
+        self.redis_kwargs = redis_kwargs
+
+    def init_async_client(self):
+        from ._redis import get_redis_async_client
+
+        return get_redis_async_client(**self.redis_kwargs)
 
     def set_cache(self, key, value, **kwargs):
         ttl = kwargs.get("ttl", None)
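The constructor now keeps the raw connection kwargs and builds a fresh asyncio client on demand through init_async_client(); each async operation below opens the client inside an `async with` block and lets it close when the block exits. A small sketch of that pattern in isolation (helper names and kwargs here are placeholders, not the library's API):

```python
# Pattern sketch: a short-lived asyncio Redis client per operation.
from typing import Optional

import redis.asyncio as async_redis


def init_async_client(redis_kwargs: dict) -> async_redis.Redis:
    # construct the client; no connection is opened until it is first used
    return async_redis.Redis(**redis_kwargs)


async def set_once(redis_kwargs: dict, key: str, value: str, ttl: Optional[int] = None) -> None:
    # the connection is released again when the `async with` block exits
    async with init_async_client(redis_kwargs) as client:
        await client.set(name=key, value=value, ex=ttl)
```

This favors simplicity over connection reuse; a long-lived client or an explicit connection pool would satisfy the same interface.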
@@ -88,6 +110,34 @@ class RedisCache(BaseCache):
             # NON blocking - notify users Redis is throwing an exception
             logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
 
+    async def async_set_cache(self, key, value, **kwargs):
+        async with self.init_async_client() as redis_client:
+            ttl = kwargs.get("ttl", None)
+            print_verbose(
+                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
+            )
+            try:
+                await redis_client.set(name=key, value=str(value), ex=ttl)
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+
+    def _get_cache_logic(self, cached_response: Any):
+        """
+        Common 'get_cache_logic' across sync + async redis client implementations
+        """
+        if cached_response is None:
+            return cached_response
+        # cached_response is in `b{} convert it to ModelResponse
+        cached_response = cached_response.decode("utf-8")  # Convert bytes to string
+        try:
+            cached_response = json.loads(
+                cached_response
+            )  # Convert string to dictionary
+        except:
+            cached_response = ast.literal_eval(cached_response)
+        return cached_response
+
     def get_cache(self, key, **kwargs):
         try:
             print_verbose(f"Get Redis Cache: key: {key}")
@@ -95,26 +145,33 @@ class RedisCache(BaseCache):
             print_verbose(
                 f"Got Redis Cache: key: {key}, cached_response {cached_response}"
             )
-            if cached_response != None:
-                # cached_response is in `b{} convert it to ModelResponse
-                cached_response = cached_response.decode(
-                    "utf-8"
-                )  # Convert bytes to string
-                try:
-                    cached_response = json.loads(
-                        cached_response
-                    )  # Convert string to dictionary
-                except:
-                    cached_response = ast.literal_eval(cached_response)
-                return cached_response
+            return self._get_cache_logic(cached_response=cached_response)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
             traceback.print_exc()
             logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
 
+    async def async_get_cache(self, key, **kwargs):
+        async with self.init_async_client() as redis_client:
+            try:
+                print_verbose(f"Get Redis Cache: key: {key}")
+                cached_response = await redis_client.get(key)
+                print_verbose(
+                    f"Got Async Redis Cache: key: {key}, cached_response {cached_response}"
+                )
+                response = self._get_cache_logic(cached_response=cached_response)
+                return response
+            except Exception as e:
+                # NON blocking - notify users Redis is throwing an exception
+                traceback.print_exc()
+                logging.debug("LiteLLM Caching: get() - Got exception from REDIS: ", e)
+
     def flush_cache(self):
         self.redis_client.flushall()
 
+    async def disconnect(self):
+        pass
+
 
 class S3Cache(BaseCache):
     def __init__(
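With the pieces above, RedisCache can be driven from async code directly. A usage sketch, assuming the class lives in litellm.caching and a Redis server is reachable locally:

```python
# Minimal sketch; module path and connection details are assumptions.
import asyncio

from litellm.caching import RedisCache


async def demo():
    cache = RedisCache(host="localhost", port=6379)
    await cache.async_set_cache("greeting", {"text": "hello"}, ttl=120)
    cached = await cache.async_get_cache("greeting")
    print(cached)  # {'text': 'hello'} once decoded by _get_cache_logic


asyncio.run(demo())
```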
@@ -189,6 +246,9 @@ class S3Cache(BaseCache):
             # NON blocking - notify users S3 is throwing an exception
             print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
 
+    async def async_set_cache(self, key, value, **kwargs):
+        self.set_cache(key=key, value=value, **kwargs)
+
     def get_cache(self, key, **kwargs):
         import boto3, botocore
 
@@ -229,6 +289,9 @@ class S3Cache(BaseCache):
             traceback.print_exc()
             print_verbose(f"S3 Caching: get_cache() - Got exception from S3: {e}")
 
+    async def async_get_cache(self, key, **kwargs):
+        return self.get_cache(key=key, **kwargs)
+
     def flush_cache(self):
         pass
 
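Unlike the Redis backend, S3Cache's new async hooks simply call the synchronous boto3 code, so the event loop is blocked for the duration of each S3 round trip. If that ever becomes a problem, one alternative (not part of this commit, shown only as a hypothetical) is to push the blocking call onto a worker thread:

```python
# Hypothetical mixin: run the blocking S3 calls in a worker thread instead.
import asyncio


class ThreadedS3CacheMixin:
    async def async_get_cache(self, key, **kwargs):
        # asyncio.to_thread (Python 3.9+) keeps the event loop free while boto3 blocks
        return await asyncio.to_thread(self.get_cache, key, **kwargs)

    async def async_set_cache(self, key, value, **kwargs):
        await asyncio.to_thread(self.set_cache, key, value, **kwargs)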
@@ -468,6 +531,45 @@ class Cache:
                 }
                 time.sleep(0.02)
 
+    def _get_cache_logic(
+        self,
+        cached_result: Optional[Any],
+        max_age: Optional[float],
+    ):
+        """
+        Common get cache logic across sync + async implementations
+        """
+        # Check if a timestamp was stored with the cached response
+        if (
+            cached_result is not None
+            and isinstance(cached_result, dict)
+            and "timestamp" in cached_result
+        ):
+            timestamp = cached_result["timestamp"]
+            current_time = time.time()
+
+            # Calculate age of the cached response
+            response_age = current_time - timestamp
+
+            # Check if the cached response is older than the max-age
+            if max_age is not None and response_age > max_age:
+                return None  # Cached response is too old
+
+            # If the response is fresh, or there's no max-age requirement, return the cached response
+            # cached_response is in `b{} convert it to ModelResponse
+            cached_response = cached_result.get("response")
+            try:
+                if isinstance(cached_response, dict):
+                    pass
+                else:
+                    cached_response = json.loads(
+                        cached_response  # type: ignore
+                    )  # Convert string to dictionary
+            except:
+                cached_response = ast.literal_eval(cached_response)  # type: ignore
+            return cached_response
+        return cached_result
+
     def get_cache(self, *args, **kwargs):
         """
         Retrieves the cached result for the given arguments.
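Cache._get_cache_logic() centralizes the freshness check: a cached entry carries the timestamp it was written at, and an `s-maxage`/`s-max-age` cache-control value becomes a maximum allowed age. A worked example of that arithmetic, with invented values:

```python
# Worked example of the freshness check (values are illustrative).
import time

cached_result = {"timestamp": time.time() - 90, "response": '{"id": "chatcmpl-123"}'}
max_age = 60  # seconds, e.g. from cache={"s-maxage": 60}

response_age = time.time() - cached_result["timestamp"]  # ~90 seconds
print(response_age > max_age)  # True -> entry is stale, _get_cache_logic returns None
```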
@@ -490,53 +592,40 @@ class Cache:
                     "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                 )
                 cached_result = self.cache.get_cache(cache_key)
-                # Check if a timestamp was stored with the cached response
-                if (
-                    cached_result is not None
-                    and isinstance(cached_result, dict)
-                    and "timestamp" in cached_result
-                ):
-                    timestamp = cached_result["timestamp"]
-                    current_time = time.time()
-
-                    # Calculate age of the cached response
-                    response_age = current_time - timestamp
-
-                    # Check if the cached response is older than the max-age
-                    if max_age is not None and response_age > max_age:
-                        print_verbose(
-                            f"Cached response for key {cache_key} is too old. Max-age: {max_age}s, Age: {response_age}s"
-                        )
-                        return None  # Cached response is too old
-
-                    # If the response is fresh, or there's no max-age requirement, return the cached response
-                    # cached_response is in `b{} convert it to ModelResponse
-                    cached_response = cached_result.get("response")
-                    try:
-                        if isinstance(cached_response, dict):
-                            pass
-                        else:
-                            cached_response = json.loads(
-                                cached_response
-                            )  # Convert string to dictionary
-                    except:
-                        cached_response = ast.literal_eval(cached_response)
-                    return cached_response
-                return cached_result
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
         except Exception as e:
             print_verbose(f"An exception occurred: {traceback.format_exc()}")
             return None
 
-    def add_cache(self, result, *args, **kwargs):
+    async def async_get_cache(self, *args, **kwargs):
         """
-        Adds a result to the cache.
+        Async get cache implementation.
 
-        Args:
-            *args: args to litellm.completion() or embedding()
-            **kwargs: kwargs to litellm.completion() or embedding()
-
-        Returns:
-            None
+        Used for embedding calls in async wrapper
+        """
+        try:  # never block execution
+            if "cache_key" in kwargs:
+                cache_key = kwargs["cache_key"]
+            else:
+                cache_key = self.get_cache_key(*args, **kwargs)
+            if cache_key is not None:
+                cache_control_args = kwargs.get("cache", {})
+                max_age = cache_control_args.get(
+                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
+                )
+                cached_result = await self.cache.async_get_cache(cache_key)
+                return self._get_cache_logic(
+                    cached_result=cached_result, max_age=max_age
+                )
+        except Exception as e:
+            print_verbose(f"An exception occurred: {traceback.format_exc()}")
+            return None
+
+    def _add_cache_logic(self, result, *args, **kwargs):
+        """
+        Common implementation across sync + async add_cache functions
         """
         try:
             if "cache_key" in kwargs:
@@ -555,17 +644,49 @@ class Cache:
                         if k == "ttl":
                             kwargs["ttl"] = v
                 cached_data = {"timestamp": time.time(), "response": result}
-                self.cache.set_cache(cache_key, cached_data, **kwargs)
+                return cache_key, cached_data
+            else:
+                raise Exception("cache key is None")
+        except Exception as e:
+            raise e
+
+    def add_cache(self, result, *args, **kwargs):
+        """
+        Adds a result to the cache.
+
+        Args:
+            *args: args to litellm.completion() or embedding()
+            **kwargs: kwargs to litellm.completion() or embedding()
+
+        Returns:
+            None
+        """
+        try:
+            cache_key, cached_data = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            self.cache.set_cache(cache_key, cached_data, **kwargs)
         except Exception as e:
             print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
             traceback.print_exc()
             pass
 
-    async def _async_add_cache(self, result, *args, **kwargs):
-        self.add_cache(result, *args, **kwargs)
+    async def async_add_cache(self, result, *args, **kwargs):
+        """
+        Async implementation of add_cache
+        """
+        try:
+            cache_key, cached_data = self._add_cache_logic(
+                result=result, *args, **kwargs
+            )
+            await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
+        except Exception as e:
+            print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
+            traceback.print_exc()
 
-    async def _async_get_cache(self, *args, **kwargs):
-        return self.get_cache(*args, **kwargs)
+    async def disconnect(self):
+        if hasattr(self.cache, "disconnect"):
+            await self.cache.disconnect()
 
 
 def enable_cache(
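After this refactor, add_cache() and async_add_cache() share _add_cache_logic() and differ only in whether the backend write is awaited, and Cache.disconnect() gives callers a single teardown hook. A round-trip sketch using the public Cache API; the in-memory backend is assumed as the default, and type="redis" would exercise the new async Redis path instead:

```python
# Sketch only: cache a fake response under an explicit cache_key, read it back async.
import asyncio

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory backend by default (assumption)


async def round_trip():
    messages = [{"role": "user", "content": "hi"}]
    cache_key = litellm.cache.get_cache_key(model="gpt-3.5-turbo", messages=messages)

    # write through the sync path, read through the new async path
    litellm.cache.add_cache(result={"id": "chatcmpl-demo"}, cache_key=cache_key)
    stored = await litellm.cache.async_get_cache(model="gpt-3.5-turbo", messages=messages)
    print(stored)  # {'id': 'chatcmpl-demo'}


asyncio.run(round_trip())
```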
@@ -346,7 +346,7 @@ def run_server(
         import gunicorn.app.base
     except:
         raise ImportError(
-            "Uvicorn, gunicorn needs to be imported. Run - `pip 'litellm[proxy]'`"
+            "uvicorn, gunicorn needs to be imported. Run - `pip install 'litellm[proxy]'`"
         )
 
     if config is not None:
@@ -427,36 +427,40 @@ def run_server(
             f"\033[1;34mSee all Router/Swagger docs on http://0.0.0.0:{port} \033[0m\n"
         )  # noqa
 
-        # Gunicorn Application Class
-        class StandaloneApplication(gunicorn.app.base.BaseApplication):
-            def __init__(self, app, options=None):
-                self.options = options or {}  # gunicorn options
-                self.application = app  # FastAPI app
-                super().__init__()
+        uvicorn.run(
+            "litellm.proxy.proxy_server:app", host=host, port=port, workers=num_workers
+        )
+
+        # # Gunicorn Application Class
+        # class StandaloneApplication(gunicorn.app.base.BaseApplication):
+        #     def __init__(self, app, options=None):
+        #         self.options = options or {}  # gunicorn options
+        #         self.application = app  # FastAPI app
+        #         super().__init__()
 
-            def load_config(self):
-                # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
-                config = {
-                    key: value
-                    for key, value in self.options.items()
-                    if key in self.cfg.settings and value is not None
-                }
-                for key, value in config.items():
-                    self.cfg.set(key.lower(), value)
+        #     def load_config(self):
+        #         # note: This Loads the gunicorn config - has nothing to do with LiteLLM Proxy config
+        #         config = {
+        #             key: value
+        #             for key, value in self.options.items()
+        #             if key in self.cfg.settings and value is not None
+        #         }
+        #         for key, value in config.items():
+        #             self.cfg.set(key.lower(), value)
 
-            def load(self):
-                # gunicorn app function
-                return self.application
+        #     def load(self):
+        #         # # gunicorn app function
+        #         return self.application
 
-        gunicorn_options = {
-            "bind": f"{host}:{port}",
-            "workers": num_workers,  # default is 1
-            "worker_class": "uvicorn.workers.UvicornWorker",
-            "preload": True,  # Add the preload flag
-        }
-        from litellm.proxy.proxy_server import app
+        # gunicorn_options = {
+        #     "bind": f"{host}:{port}",
+        #     "workers": num_workers,  # default is 1
+        #     "worker_class": "uvicorn.workers.UvicornWorker",
+        #     "preload": True,  # Add the preload flag
+        # }
+        # from litellm.proxy.proxy_server import app
 
-        StandaloneApplication(app=app, options=gunicorn_options).run()  # Run gunicorn
+        # StandaloneApplication(app=app, options=gunicorn_options).run()  # Run gunicorn
 
 
 if __name__ == "__main__":
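The proxy CLI now hands the ASGI app straight to uvicorn instead of the (now commented-out) gunicorn wrapper. The call below is roughly what run_server() does after this change; host, port, and worker count are shown as placeholder values:

```python
# Equivalent standalone invocation (placeholder host/port/worker values).
import uvicorn

uvicorn.run(
    "litellm.proxy.proxy_server:app",  # import string, as in the diff
    host="0.0.0.0",
    port=8000,
    workers=4,  # maps to the CLI's num_workers
)
```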
@@ -7,6 +7,20 @@ import secrets, subprocess
 import hashlib, uuid
 import warnings
 import importlib
+import warnings
+
+
+def showwarning(message, category, filename, lineno, file=None, line=None):
+    traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
+    if file is not None:
+        file.write(traceback_info)
+
+
+warnings.showwarning = showwarning
+warnings.filterwarnings("default", category=UserWarning)
+
+# Your client code here
+
 
 messages: list = []
 sys.path.insert(
@@ -2510,10 +2524,12 @@ async def get_routes():
 @router.on_event("shutdown")
 async def shutdown_event():
     global prisma_client, master_key, user_custom_auth
-    if prisma_client:
+    if prisma_client is not None:
         verbose_proxy_logger.debug("Disconnecting from Prisma")
         await prisma_client.disconnect()
 
+    if litellm.cache is not None:
+        await litellm.cache.disconnect()
     ## RESET CUSTOM VARIABLES ##
     cleanup_router_config_variables()
 
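The shutdown handler now also disconnects the global litellm cache, so async Redis connections are closed when the proxy stops. A stripped-down sketch of the same pattern in a standalone FastAPI app (hypothetical app, not the proxy itself):

```python
# Hypothetical standalone example of the shutdown pattern above.
import litellm
from fastapi import FastAPI

app = FastAPI()


@app.on_event("shutdown")
async def shutdown_event():
    if litellm.cache is not None:
        await litellm.cache.disconnect()
```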
@@ -266,8 +266,9 @@ async def test_embedding_caching_azure_individual_items():
     """
     Tests caching for individual items in an embedding list
 
-    Assert if the same embeddingresponse object is returned for the duplicate item in 2 embedding list calls
-
+    - Cache an item
+    - call aembedding(..) with the item + 1 unique item
+    - compare to a 2nd aembedding (...) with 2 unique items
     ```
     embedding_1 = ["hey how's it going", "I'm doing well"]
     embedding_val_1 = embedding(...)
@@ -280,31 +281,98 @@ async def test_embedding_caching_azure_individual_items():
     """
     litellm.cache = Cache()
     common_msg = f"hey how's it going {uuid.uuid4()}"
-    embedding_1 = [common_msg, "I'm doing well"]
-    embedding_2 = [common_msg, "I'm fine"]
+    common_msg_2 = f"hey how's it going {uuid.uuid4()}"
+    embedding_1 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+        common_msg,
+        common_msg,
+        common_msg,
+    ] * 20
+    embedding_2 = [
+        common_msg,
+        f"I'm fine {uuid.uuid4()}",
+        common_msg,
+        common_msg,
+        common_msg,
+    ] * 20
+    embedding_3 = [
+        common_msg_2,
+        common_msg_2,
+        common_msg_2,
+        common_msg_2,
+        f"I'm fine {uuid.uuid4()}",
+    ] * 20  # make sure azure doesn't return cached 'i'm fine' responses
 
     embedding_val_1 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_1, caching=True
     )
 
+    second_response_start_time = time.time()
     embedding_val_2 = await aembedding(
         model="azure/azure-embedding-model", input=embedding_2, caching=True
     )
-    print(f"embedding_val_2: {embedding_val_2}")
-    if (
-        embedding_val_2["data"][0]["embedding"]
-        != embedding_val_1["data"][0]["embedding"]
-    ):
-        print(f"embedding1: {embedding_val_1}")
-        print(f"embedding2: {embedding_val_2}")
-        pytest.fail("Error occurred: Embedding caching failed")
-    if (
-        embedding_val_2["data"][1]["embedding"]
-        == embedding_val_1["data"][1]["embedding"]
-    ):
-        print(f"embedding1: {embedding_val_1}")
-        print(f"embedding2: {embedding_val_2}")
-        pytest.fail("Error occurred: Embedding caching failed")
+    if embedding_val_2 is not None:
+        second_response_end_time = time.time()
+        second_response_time = second_response_end_time - second_response_start_time
+
+    third_response_start_time = time.time()
+    embedding_val_3 = await aembedding(
+        model="azure/azure-embedding-model", input=embedding_3, cache={"no-cache": True}
+    )
+    if embedding_val_3 is not None:
+        third_response_end_time = time.time()
+        third_response_time = third_response_end_time - third_response_start_time
+
+    print(f"second_response_time: {second_response_time}")
+    print(f"third_response_time: {third_response_time}")
+
+    assert (
+        second_response_time < third_response_time - 0.5
+    )  # make sure it's actually faster
+    raise Exception(f"it works {second_response_time} < {third_response_time}")
+
+
+@pytest.mark.asyncio
+async def test_redis_cache_basic():
+    """
+    Init redis client
+    - write to client
+    - read from client
+    """
+    litellm.set_verbose = False
+
+    random_number = random.randint(
+        1, 100000
+    )  # add a random number to ensure it's always adding / reading from cache
+    messages = [
+        {"role": "user", "content": f"write a one sentence poem about: {random_number}"}
+    ]
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    response1 = completion(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+
+    cache_key = litellm.cache.get_cache_key(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"cache_key: {cache_key}")
+    litellm.cache.add_cache(result=response1, cache_key=cache_key)
+    print(f"cache key pre async get: {cache_key}")
+    stored_val = await litellm.cache.async_get_cache(
+        model="gpt-3.5-turbo",
+        messages=messages,
+    )
+    print(f"stored_val: {stored_val}")
+    assert stored_val["id"] == response1.id
+    raise Exception("it worked!")
 
 
 def test_redis_cache_completion():
@@ -2214,8 +2214,13 @@ def client(original_function):
             )
             # if caching is false, don't run this
             final_embedding_cached_response = None
+
             if (
-                (kwargs.get("caching", None) is None and litellm.cache is not None)
+                (
+                    kwargs.get("caching", None) is None
+                    and kwargs.get("cache", None) is None
+                    and litellm.cache is not None
+                )
                 or kwargs.get("caching", False) == True
                 or (
                     kwargs.get("cache", None) is not None
@@ -2234,12 +2239,13 @@ def client(original_function):
                     kwargs["input"], list
                 ):
                     tasks = []
-                    embedding_kwargs = copy.deepcopy(kwargs)
                     for idx, i in enumerate(kwargs["input"]):
-                        embedding_kwargs["input"] = i
+                        preset_cache_key = litellm.cache.get_cache_key(
+                            *args, **{**kwargs, "input": i}
+                        )
                         tasks.append(
-                            litellm.cache._async_get_cache(
-                                *args, **embedding_kwargs
+                            litellm.cache.async_get_cache(
+                                cache_key=preset_cache_key
                             )
                         )
                     cached_result = await asyncio.gather(*tasks)
@@ -2445,24 +2451,28 @@ def client(original_function):
                 if isinstance(result, EmbeddingResponse) and isinstance(
                     kwargs["input"], list
                 ):
-                    embedding_kwargs = copy.deepcopy(kwargs)
                     for idx, i in enumerate(kwargs["input"]):
+                        preset_cache_key = litellm.cache.get_cache_key(
+                            *args, **{**kwargs, "input": i}
+                        )
                         embedding_response = result.data[idx]
-                        embedding_kwargs["input"] = i
                         asyncio.create_task(
-                            litellm.cache._async_add_cache(
-                                embedding_response, *args, **embedding_kwargs
+                            litellm.cache.async_add_cache(
+                                embedding_response,
+                                *args,
+                                cache_key=preset_cache_key,
                             )
                         )
+                        # pass
                 else:
                     asyncio.create_task(
-                        litellm.cache._async_add_cache(
+                        litellm.cache.async_add_cache(
                             result.json(), *args, **kwargs
                         )
                     )
             else:
                 asyncio.create_task(
-                    litellm.cache._async_add_cache(result, *args, **kwargs)
+                    litellm.cache.async_add_cache(result, *args, **kwargs)
                 )
             # LOG SUCCESS - handle streaming success logging in the _next_ object
             print_verbose(
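In the wrapper above, embedding calls no longer deep-copy kwargs per item; instead a preset cache key is computed with `input` narrowed to the single item, the per-item lookups are fanned out with asyncio.gather, and writes go through async_add_cache under the same keys. A standalone sketch of the key-per-item lookup; the cache configuration and kwargs are illustrative:

```python
# Sketch of the per-item cache-key fan-out used in the wrapper.
import asyncio

import litellm
from litellm.caching import Cache

litellm.cache = Cache()  # in-memory backend assumed for the example

kwargs = {"model": "azure/azure-embedding-model", "input": ["hello", "world"]}


async def lookup_each_item():
    tasks = []
    for item in kwargs["input"]:
        # same trick as the wrapper: hash the call with `input` narrowed to one item
        preset_cache_key = litellm.cache.get_cache_key(**{**kwargs, "input": item})
        tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
    return await asyncio.gather(*tasks)


print(asyncio.run(lookup_each_item()))  # [None, None] until something is cached
```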