fix(caching.py): remove url parsing logic - causing redis ssl connections to fail

This reverts a change that was causing Redis URLs with SSL to fail. It also adds unit tests for this scenario, to prevent future regressions.
This commit is contained in:
Krrish Dholakia 2024-04-19 14:01:13 -07:00
parent 9dc0871023
commit 7065e4ee12
4 changed files with 84 additions and 14 deletions

View file

@ -154,13 +154,6 @@ class RedisCache(BaseCache):
self.redis_kwargs = redis_kwargs self.redis_kwargs = redis_kwargs
self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs) self.async_redis_conn_pool = get_redis_connection_pool(**redis_kwargs)
if "url" in redis_kwargs and redis_kwargs["url"] is not None:
parsed_kwargs = redis.connection.parse_url(redis_kwargs["url"])
redis_kwargs.update(parsed_kwargs)
self.redis_kwargs.update(parsed_kwargs)
# pop url
self.redis_kwargs.pop("url")
# redis namespaces # redis namespaces
self.namespace = namespace self.namespace = namespace
# for high traffic, we store the redis results in memory and then batch write to redis # for high traffic, we store the redis results in memory and then batch write to redis
@ -175,6 +168,12 @@ class RedisCache(BaseCache):
### HEALTH MONITORING OBJECT ### ### HEALTH MONITORING OBJECT ###
self.service_logger_obj = ServiceLogging() self.service_logger_obj = ServiceLogging()
### ASYNC HEALTH PING ###
try:
asyncio.get_running_loop().create_task(self.ping())
except Exception:
pass
def init_async_client(self): def init_async_client(self):
from ._redis import get_redis_async_client from ._redis import get_redis_async_client
@ -601,13 +600,31 @@ class RedisCache(BaseCache):
print_verbose(f"Error occurred in pipeline read - {str(e)}") print_verbose(f"Error occurred in pipeline read - {str(e)}")
return key_value_dict return key_value_dict
async def ping(self): def sync_ping(self) -> bool:
"""
Tests if the sync redis client is correctly setup.
"""
print_verbose(f"Pinging Async Redis Cache")
try:
response = self.redis_client.ping()
print_verbose(f"Redis Cache PING: {response}")
return response
except Exception as e:
# NON blocking - notify users Redis is throwing an exception
print_verbose(
f"LiteLLM Redis Cache PING: - Got exception from REDIS : {str(e)}"
)
traceback.print_exc()
raise e
async def ping(self) -> bool:
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
async with _redis_client as redis_client: async with _redis_client as redis_client:
print_verbose(f"Pinging Async Redis Cache") print_verbose(f"Pinging Async Redis Cache")
try: try:
response = await redis_client.ping() response = await redis_client.ping()
print_verbose(f"Redis Cache PING: {response}") print_verbose(f"Redis Cache PING: {response}")
return response
except Exception as e: except Exception as e:
# NON blocking - notify users Redis is throwing an exception # NON blocking - notify users Redis is throwing an exception
print_verbose( print_verbose(

View file

@ -67,7 +67,7 @@ class PrometheusLogger:
# unpack kwargs # unpack kwargs
model = kwargs.get("model", "") model = kwargs.get("model", "")
response_cost = kwargs.get("response_cost", 0.0) response_cost = kwargs.get("response_cost", 0.0) or 0
litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request") or {} proxy_server_request = litellm_params.get("proxy_server_request") or {}
end_user_id = proxy_server_request.get("body", {}).get("user", None) end_user_id = proxy_server_request.get("body", {}).get("user", None)

View file

@ -3,8 +3,8 @@ model_list:
litellm_params: litellm_params:
model: openai/my-fake-model model: openai/my-fake-model
api_key: my-fake-key api_key: my-fake-key
# api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/ api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
api_base: http://0.0.0.0:8080 # api_base: http://0.0.0.0:8080
stream_timeout: 0.001 stream_timeout: 0.001
rpm: 10 rpm: 10
- litellm_params: - litellm_params:
@ -33,9 +33,7 @@ litellm_settings:
router_settings: router_settings:
routing_strategy: usage-based-routing-v2 routing_strategy: usage-based-routing-v2
redis_host: os.environ/REDIS_HOST redis_url: "rediss://:073f655645b843c4839329aea8384e68@us1-great-lizard-40486.upstash.io:40486/0"
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
enable_pre_call_checks: True enable_pre_call_checks: True
general_settings: general_settings:

View file

@ -15,6 +15,61 @@ from litellm import Router
## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group ## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group
@pytest.mark.asyncio
async def test_router_async_caching_with_ssl_url():
"""
Tests when a redis url is passed to the router, if caching is correctly setup
"""
try:
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
],
redis_url=os.getenv("REDIS_URL"),
)
response = await router.cache.redis_cache.ping()
print(f"response: {response}")
assert response == True
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
def test_router_sync_caching_with_ssl_url():
"""
Tests when a redis url is passed to the router, if caching is correctly setup
"""
try:
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo-0613",
"api_key": os.getenv("OPENAI_API_KEY"),
},
"tpm": 100000,
"rpm": 10000,
},
],
redis_url=os.getenv("REDIS_URL"),
)
response = router.cache.redis_cache.sync_ping()
print(f"response: {response}")
assert response == True
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_acompletion_caching_on_router(): async def test_acompletion_caching_on_router():
# tests acompletion + caching on router # tests acompletion + caching on router