mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
fix(_redis.py): support additional params for redis
This commit is contained in:
parent
e615f2670a
commit
88c95ca259
5 changed files with 135 additions and 29 deletions
85
litellm/_redis.py
Normal file
85
litellm/_redis.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
# +-----------------------------------------------+
|
||||||
|
# | |
|
||||||
|
# | Give Feedback / Get Help |
|
||||||
|
# | https://github.com/BerriAI/litellm/issues/new |
|
||||||
|
# | |
|
||||||
|
# +-----------------------------------------------+
|
||||||
|
#
|
||||||
|
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||||
|
|
||||||
|
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||||
|
import os
|
||||||
|
import inspect
|
||||||
|
import redis, litellm
|
||||||
|
|
||||||
|
def _get_redis_kwargs():
    """Return the keyword-argument names accepted when building a Redis client.

    The names are taken from the signature of ``redis.Redis``; ``self``,
    ``connection_pool`` and ``retry`` are dropped, and ``"url"`` is appended
    at the end.

    Returns:
        list: argument names, in signature order, followed by ``"url"``.
    """
    # Only allow primitive arguments
    blocked = ("self", "connection_pool", "retry")
    extra = ["url"]

    signature_args = inspect.getfullargspec(redis.Redis).args

    allowed = []
    for name in signature_args:
        if name in blocked:
            continue
        allowed.append(name)
    allowed.extend(extra)

    return allowed
|
||||||
|
|
||||||
|
def _get_redis_env_kwarg_mapping():
    """Map environment-variable names to Redis keyword-argument names.

    Each name returned by ``_get_redis_kwargs()`` is upper-cased and prefixed
    with ``REDIS_`` to form the key; the value is the original argument name.

    Returns:
        dict: e.g. an argument named ``host`` yields ``{"REDIS_HOST": "host"}``.
    """
    env_to_kwarg = {}
    for kwarg_name in _get_redis_kwargs():
        env_to_kwarg["REDIS_" + kwarg_name.upper()] = kwarg_name
    return env_to_kwarg
|
||||||
|
|
||||||
|
|
||||||
|
def _redis_kwargs_from_environment():
    """Collect Redis keyword arguments from the configured secrets.

    Every ``REDIS_*`` name in the environment mapping is looked up through
    ``litellm.get_secret``. Names whose lookup yields ``None`` are left out.
    Values are stored exactly as returned; no type conversion is applied.

    Returns:
        dict: keyword-argument name -> retrieved value.
    """
    found = {}
    for env_name, kwarg_name in _get_redis_env_kwarg_mapping().items():
        secret = litellm.get_secret(env_name, default_value=None)  # check os.environ/key vault
        if secret is None:
            continue
        found[kwarg_name] = secret
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def get_redis_url_from_environment():
    """Build a Redis connection URL from environment variables.

    ``REDIS_URL`` is returned unchanged when it is set. Otherwise the URL is
    assembled from ``REDIS_HOST`` and ``REDIS_PORT``, with ``REDIS_PASSWORD``
    added as the credential part when present.

    Returns:
        str: the connection URL, e.g. ``redis://:secret@localhost:6379``.

    Raises:
        ValueError: if ``REDIS_URL`` is unset and either ``REDIS_HOST`` or
            ``REDIS_PORT`` is missing.
    """
    from urllib.parse import quote

    env = os.environ

    if "REDIS_URL" in env:
        return env["REDIS_URL"]

    if "REDIS_HOST" not in env or "REDIS_PORT" not in env:
        raise ValueError("Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis.")

    password = env.get("REDIS_PASSWORD")
    if password is None:
        credentials = ""
    else:
        # REDIS_PASSWORD holds the raw password, so reserved URL characters
        # ('@', '/', ':', '#', ...) must be percent-encoded or the URL is
        # split at the wrong place when it is parsed.
        encoded_password = quote(password, safe="")
        credentials = f":{encoded_password}@"

    return f"redis://{credentials}{env['REDIS_HOST']}:{env['REDIS_PORT']}"
|
||||||
|
|
||||||
|
def get_redis_client(**env_overrides):
    """Create a Redis client from environment settings and explicit overrides.

    Keyword arguments gathered by ``_redis_kwargs_from_environment()`` form
    the base configuration; anything passed in ``env_overrides`` replaces the
    environment value of the same name.

    When a non-``None`` ``url`` is configured, the client is built with
    ``redis.Redis.from_url`` and ``host``, ``port``, ``db`` and ``password``
    are discarded. Otherwise a ``host`` is required and the client is built
    with the ``redis.Redis`` constructor.

    Args:
        **env_overrides: keyword arguments that take precedence over the
            environment-derived ones.

    Returns:
        The client returned by ``redis.Redis.from_url`` or ``redis.Redis``.

    Raises:
        ValueError: if neither a ``url`` nor a ``host`` is configured.
    """
    redis_kwargs = {
        **_redis_kwargs_from_environment(),
        **env_overrides,
    }

    if redis_kwargs.get("url") is not None:
        # These settings are dropped so that only the URL describes the connection.
        for key in ("host", "port", "db", "password"):
            redis_kwargs.pop(key, None)
        return redis.Redis.from_url(**redis_kwargs)

    if redis_kwargs.get("host") is None:
        raise ValueError("Either 'host' or 'url' must be specified for redis.")

    # "url" is only understood by from_url. A caller may pass url=None
    # explicitly, and forwarding it to the constructor would raise TypeError.
    redis_kwargs.pop("url", None)

    return redis.Redis(**redis_kwargs)
|
|
@ -69,10 +69,22 @@ class InMemoryCache(BaseCache):
|
||||||
|
|
||||||
|
|
||||||
class RedisCache(BaseCache):
|
class RedisCache(BaseCache):
|
||||||
def __init__(self, host, port, password, **kwargs):
|
def __init__(self, host=None, port=None, password=None, **kwargs):
|
||||||
import redis
|
import redis
|
||||||
# if users don't provide one, use the default litellm cache
|
# if users don't provide one, use the default litellm cache
|
||||||
self.redis_client = redis.Redis(host=host, port=port, password=password, **kwargs)
|
from ._redis import get_redis_client
|
||||||
|
|
||||||
|
redis_kwargs = {}
|
||||||
|
if host is not None:
|
||||||
|
redis_kwargs["host"] = host
|
||||||
|
if port is not None:
|
||||||
|
redis_kwargs["port"] = port
|
||||||
|
if password is not None:
|
||||||
|
redis_kwargs["password"] = password
|
||||||
|
|
||||||
|
redis_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
self.redis_client = get_redis_client(**redis_kwargs)
|
||||||
|
|
||||||
def set_cache(self, key, value, **kwargs):
|
def set_cache(self, key, value, **kwargs):
|
||||||
ttl = kwargs.get("ttl", None)
|
ttl = kwargs.get("ttl", None)
|
||||||
|
|
|
@ -477,9 +477,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
print(f"{blue_color_code}\nSetting Cache on Proxy")
|
print(f"{blue_color_code}\nSetting Cache on Proxy")
|
||||||
from litellm.caching import Cache
|
from litellm.caching import Cache
|
||||||
cache_type = value["type"]
|
cache_type = value["type"]
|
||||||
cache_host = litellm.get_secret("REDIS_HOST")
|
cache_host = litellm.get_secret("REDIS_HOST", None)
|
||||||
cache_port = litellm.get_secret("REDIS_PORT")
|
cache_port = litellm.get_secret("REDIS_PORT", None)
|
||||||
cache_password = litellm.get_secret("REDIS_PASSWORD")
|
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
||||||
|
|
||||||
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
||||||
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
|
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
|
||||||
|
@ -488,6 +488,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
print(f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}")
|
print(f"{blue_color_code}Cache Password:{reset_color_code} {cache_password}")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
## to pass a complete url, set `os.environ["REDIS_URL"] = <your-redis-url>`; _redis.py checks for REDIS_-prefixed environment variables
|
||||||
litellm.cache = Cache(
|
litellm.cache = Cache(
|
||||||
type=cache_type,
|
type=cache_type,
|
||||||
host=cache_host,
|
host=cache_host,
|
||||||
|
|
|
@ -60,10 +60,14 @@ class Router:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_list: Optional[list] = None,
|
model_list: Optional[list] = None,
|
||||||
|
## CACHING ##
|
||||||
|
redis_url: Optional[str] = None,
|
||||||
redis_host: Optional[str] = None,
|
redis_host: Optional[str] = None,
|
||||||
redis_port: Optional[int] = None,
|
redis_port: Optional[int] = None,
|
||||||
redis_password: Optional[str] = None,
|
redis_password: Optional[str] = None,
|
||||||
cache_responses: bool = False,
|
cache_responses: bool = False,
|
||||||
|
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
|
||||||
|
## RELIABILITY ##
|
||||||
num_retries: int = 0,
|
num_retries: int = 0,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
default_litellm_params = {}, # default params for Router.chat.completion.create
|
default_litellm_params = {}, # default params for Router.chat.completion.create
|
||||||
|
@ -107,21 +111,21 @@ class Router:
|
||||||
if self.routing_strategy == "least-busy":
|
if self.routing_strategy == "least-busy":
|
||||||
self._start_health_check_thread()
|
self._start_health_check_thread()
|
||||||
### CACHING ###
|
### CACHING ###
|
||||||
|
cache_type = "local" # default to an in-memory cache
|
||||||
redis_cache = None
|
redis_cache = None
|
||||||
if redis_host is not None and redis_port is not None and redis_password is not None:
|
cache_config = {}
|
||||||
|
if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None):
|
||||||
|
cache_type = "redis"
|
||||||
cache_config = {
|
cache_config = {
|
||||||
'type': 'redis',
|
'url': redis_url,
|
||||||
'host': redis_host,
|
'host': redis_host,
|
||||||
'port': redis_port,
|
'port': redis_port,
|
||||||
'password': redis_password
|
'password': redis_password,
|
||||||
}
|
**cache_kwargs
|
||||||
redis_cache = RedisCache(host=redis_host, port=redis_port, password=redis_password)
|
|
||||||
else: # use an in-memory cache
|
|
||||||
cache_config = {
|
|
||||||
"type": "local"
|
|
||||||
}
|
}
|
||||||
|
redis_cache = RedisCache(**cache_config)
|
||||||
if cache_responses:
|
if cache_responses:
|
||||||
litellm.cache = litellm.Cache(**cache_config) # use Redis for caching completion requests
|
litellm.cache = litellm.Cache(type=cache_type, **cache_config)
|
||||||
self.cache_responses = cache_responses
|
self.cache_responses = cache_responses
|
||||||
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
|
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
|
||||||
## USAGE TRACKING ##
|
## USAGE TRACKING ##
|
||||||
|
|
|
@ -4763,23 +4763,27 @@ def litellm_telemetry(data):
|
||||||
######### Secret Manager ############################
|
######### Secret Manager ############################
|
||||||
# checks if user has passed in a secret manager client
|
# checks if user has passed in a secret manager client
|
||||||
# if passed in then checks the secret there
|
# if passed in then checks the secret there
|
||||||
def get_secret(secret_name: str):
|
def get_secret(secret_name: str, default_value: Optional[str]=None):
|
||||||
if secret_name.startswith("os.environ/"):
|
if secret_name.startswith("os.environ/"):
|
||||||
secret_name = secret_name.replace("os.environ/", "")
|
secret_name = secret_name.replace("os.environ/", "")
|
||||||
if litellm.secret_manager_client is not None:
|
try:
|
||||||
# TODO: check which secret manager is being used
|
if litellm.secret_manager_client is not None:
|
||||||
# currently only supports Infisical
|
try:
|
||||||
try:
|
client = litellm.secret_manager_client
|
||||||
client = litellm.secret_manager_client
|
if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
||||||
if type(client).__module__ + '.' + type(client).__name__ == 'azure.keyvault.secrets._client.SecretClient': # support Azure Secret Client - from azure.keyvault.secrets import SecretClient
|
secret = retrieved_secret = client.get_secret(secret_name).value
|
||||||
secret = retrieved_secret = client.get_secret(secret_name).value
|
else: # assume the default is infisicial client
|
||||||
else: # assume the default is infisicial client
|
secret = client.get_secret(secret_name).secret_value
|
||||||
secret = client.get_secret(secret_name).secret_value
|
except: # check if it's in os.environ
|
||||||
except: # check if it's in os.environ
|
secret = os.environ.get(secret_name)
|
||||||
secret = os.environ.get(secret_name)
|
return secret
|
||||||
return secret
|
else:
|
||||||
else:
|
return os.environ.get(secret_name)
|
||||||
return os.environ.get(secret_name)
|
except Exception as e:
|
||||||
|
if default_value is not None:
|
||||||
|
return default_value
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
|
||||||
######## Streaming Class ############################
|
######## Streaming Class ############################
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue