From 88c95ca259f8677b1eb1e2a356e81bf50e87add1 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 5 Dec 2023 12:16:33 -0800 Subject: [PATCH] fix(_redis.py): support additional params for redis --- litellm/_redis.py | 85 +++++++++++++++++++++++++++++++++++ litellm/caching.py | 16 ++++++- litellm/proxy/proxy_server.py | 7 +-- litellm/router.py | 22 +++++---- litellm/utils.py | 34 +++++++------- 5 files changed, 135 insertions(+), 29 deletions(-) create mode 100644 litellm/_redis.py diff --git a/litellm/_redis.py b/litellm/_redis.py new file mode 100644 index 0000000000..82e0ab0ec6 --- /dev/null +++ b/litellm/_redis.py @@ -0,0 +1,85 @@ +# +-----------------------------------------------+ +# | | +# | Give Feedback / Get Help | +# | https://github.com/BerriAI/litellm/issues/new | +# | | +# +-----------------------------------------------+ +# +# Thank you users! We ❤️ you! - Krrish & Ishaan + +# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation +import os +import inspect +import redis, litellm + +def _get_redis_kwargs(): + arg_spec = inspect.getfullargspec(redis.Redis) + + # Only allow primitive arguments + exclude_args = { + "self", + "connection_pool", + "retry", + } + + + include_args = [ + "url" + ] + + available_args = [ + x for x in arg_spec.args if x not in exclude_args + ] + include_args + + return available_args + +def _get_redis_env_kwarg_mapping(): + PREFIX = "REDIS_" + + return { + f"{PREFIX}{x.upper()}": x for x in _get_redis_kwargs() + } + + +def _redis_kwargs_from_environment(): + mapping = _get_redis_env_kwarg_mapping() + + return_dict = {} + for k, v in mapping.items(): + value = litellm.get_secret(k, default_value=None) # check os.environ/key vault + if value is not None: + return_dict[v] = value + return return_dict + + +def get_redis_url_from_environment(): + if "REDIS_URL" in os.environ: + return os.environ["REDIS_URL"] + + if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ: + raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.") + + if "REDIS_PASSWORD" in os.environ: + redis_password = f":{os.environ['REDIS_PASSWORD']}@" + else: + redis_password = "" + + return f"redis://{redis_password}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}" + +def get_redis_client(**env_overrides): + redis_kwargs = { + **_redis_kwargs_from_environment(), + **env_overrides, + } + + if "url" in redis_kwargs and redis_kwargs['url'] is not None: + redis_kwargs.pop("host", None) + redis_kwargs.pop("port", None) + redis_kwargs.pop("db", None) + redis_kwargs.pop("password", None) + + return redis.Redis.from_url(**redis_kwargs) + elif "host" not in redis_kwargs or redis_kwargs['host'] is None: + raise ValueError("Either 'host' or 'url' must be specified for redis.") + + return redis.Redis(**redis_kwargs) \ No newline at end of file diff --git a/litellm/caching.py b/litellm/caching.py index d9b94b9586..1b6963cc67 100644 --- a/litellm/caching.py +++ b/litellm/caching.py @@ -69,10 +69,22 @@ class InMemoryCache(BaseCache): class RedisCache(BaseCache): - def __init__(self, host, port, password, **kwargs): + def __init__(self, host=None, port=None, password=None, **kwargs): import redis # if users don't provider one, use the default litellm cache - self.redis_client = redis.Redis(host=host, port=port, password=password, **kwargs) + from ._redis import get_redis_client + + redis_kwargs = {} + if host is not None: + redis_kwargs["host"] = host + if port is not None: + redis_kwargs["port"] = port + if password is not None: + redis_kwargs["password"] = password + + redis_kwargs.update(kwargs) + + self.redis_client = get_redis_client(**redis_kwargs) def set_cache(self, key, value, **kwargs): ttl = kwargs.get("ttl", None) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ecd4ab8d6d..3f94f90b94 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -477,9 +477,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"{blue_color_code}\nSetting Cache on Proxy") from litellm.caching import Cache cache_type = value["type"] - cache_host = litellm.get_secret("REDIS_HOST") - cache_port = litellm.get_secret("REDIS_PORT") - cache_password = litellm.get_secret("REDIS_PASSWORD") + cache_host = litellm.get_secret("REDIS_HOST", None) + cache_port = litellm.get_secret("REDIS_PORT", None) + cache_password = litellm.get_secret("REDIS_PASSWORD", None) # Assuming cache_type, cache_host, cache_port, and cache_password are strings print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}") @@ -488,6 +488,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str): print(f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}") print() + ## to pass a complete url, just set it as `os.environ[REDIS_URL] = `, _redis.py checks for REDIS specific environment variables litellm.cache = Cache( type=cache_type, host=cache_host, diff --git a/litellm/router.py b/litellm/router.py index 75fd5afd9e..478b5dd233 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -60,10 +60,14 @@ class Router: def __init__(self, model_list: Optional[list] = None, + ## CACHING ## + redis_url: Optional[str] = None, redis_host: Optional[str] = None, redis_port: Optional[int] = None, redis_password: Optional[str] = None, cache_responses: bool = False, + cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py) + ## RELIABILITY ## num_retries: int = 0, timeout: Optional[float] = None, default_litellm_params = {}, # default params for Router.chat.completion.create @@ -107,21 +111,21 @@ class Router: if self.routing_strategy == "least-busy": self._start_health_check_thread() ### CACHING ### + cache_type = "local" # default to an in-memory cache redis_cache = None - if redis_host is not None and redis_port is not None and redis_password is not None: + cache_config = {} + if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None): + cache_type = "redis" cache_config = { - 'type': 'redis', + 'url': redis_url, 'host': redis_host, 'port': redis_port, - 'password': redis_password - } - redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password) - else: # use an in-memory cache - cache_config = { - "type": "local" + 'password': redis_password, + **cache_kwargs } + redis_cache = RedisCache(**cache_config) if cache_responses: - litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests + litellm.cache = litellm.Cache(type=cache_type, **cache_config) self.cache_responses = cache_responses self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc. ## USAGE TRACKING ## diff --git a/litellm/utils.py b/litellm/utils.py index 84cd207297..c89e690d72 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4763,23 +4763,27 @@ def litellm_telemetry(data): ######### Secret Manager ############################ # checks if user has passed in a secret manager client # if passed in then checks the secret there -def get_secret(secret_name: str): +def get_secret(secret_name: str, default_value: Optional[str]=None): if secret_name.startswith("os.environ/"): secret_name = secret_name.replace("os.environ/", "") - if litellm.secret_manager_client is not None: - # TODO: check which secret manager is being used - # currently only supports Infisical - try: - client = litellm.secret_manager_client - if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient - secret = retrieved_secret = client.get_secret(secret_name).value - else: # assume the default is infisicial client - secret = client.get_secret(secret_name).secret_value - except: # check if it's in os.environ - secret = os.environ.get(secret_name) - return secret - else: - return os.environ.get(secret_name) + try: + if litellm.secret_manager_client is not None: + try: + client = litellm.secret_manager_client + if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient + secret = retrieved_secret = client.get_secret(secret_name).value + else: # assume the default is infisicial client + secret = client.get_secret(secret_name).secret_value + except: # check if it's in os.environ + secret = os.environ.get(secret_name) + return secret + else: + return os.environ.get(secret_name) + except Exception as e: + if default_value is not None: + return default_value + else: + raise e ######## Streaming Class ############################