fix(caching.py): don't decode a string

This commit is contained in:
Krrish Dholakia 2024-04-13 18:48:03 -07:00
parent 9f42d15713
commit bef24cd4ab
3 changed files with 16 additions and 6 deletions

View file

@ -505,10 +505,12 @@ class RedisCache(BaseCache):
# 'results' is a list of values corresponding to the order of keys in 'key_list'. # 'results' is a list of values corresponding to the order of keys in 'key_list'.
key_value_dict = dict(zip(key_list, results)) key_value_dict = dict(zip(key_list, results))
decoded_results = { decoded_results = {}
k.decode("utf-8"): self._get_cache_logic(v) for k, v in key_value_dict.items():
for k, v in key_value_dict.items() if isinstance(k, bytes):
} k = k.decode("utf-8")
v = self._get_cache_logic(v)
decoded_results[k] = v
return decoded_results return decoded_results
except Exception as e: except Exception as e:

View file

@ -58,7 +58,6 @@ class PrometheusServicesLogger:
def is_metric_registered(self, metric_name) -> bool: def is_metric_registered(self, metric_name) -> bool:
for metric in self.REGISTRY.collect(): for metric in self.REGISTRY.collect():
print(f"metric name: {metric.name}")
if metric_name == metric.name: if metric_name == metric.name:
return True return True
return False return False
@ -82,7 +81,7 @@ class PrometheusServicesLogger:
) )
def create_counter(self, label: str): def create_counter(self, label: str):
metric_name = "litellm_{}_requests".format(label) metric_name = "litellm_{}_failed_requests".format(label)
is_registered = self.is_metric_registered(metric_name) is_registered = self.is_metric_registered(metric_name)
if is_registered: if is_registered:
return self.get_metric(metric_name) return self.get_metric(metric_name)
@ -158,6 +157,7 @@ class PrometheusServicesLogger:
) )
async def async_service_failure_hook(self, payload: ServiceLoggerPayload): async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
print(f"received error payload: {payload.error}")
if self.mock_testing: if self.mock_testing:
self.mock_testing_failure_calls += 1 self.mock_testing_failure_calls += 1

View file

@ -54,6 +54,8 @@ async def test_completion_with_caching():
assert sl.mock_testing_async_success_hook > 0 assert sl.mock_testing_async_success_hook > 0
assert sl.prometheusServicesLogger.mock_testing_success_calls > 0 assert sl.prometheusServicesLogger.mock_testing_success_calls > 0
assert sl.mock_testing_sync_failure_hook == 0
assert sl.mock_testing_async_failure_hook == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@ -81,6 +83,8 @@ async def test_completion_with_caching_bad_call():
pass pass
assert sl.mock_testing_async_failure_hook > 0 assert sl.mock_testing_async_failure_hook > 0
assert sl.mock_testing_async_success_hook == 0
assert sl.mock_testing_sync_success_hook == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@ -134,6 +138,8 @@ async def test_router_with_caching():
response1 = await router.acompletion(model="azure/gpt-4", messages=messages) response1 = await router.acompletion(model="azure/gpt-4", messages=messages)
assert sl.mock_testing_async_success_hook > 0 assert sl.mock_testing_async_success_hook > 0
assert sl.mock_testing_sync_failure_hook == 0
assert sl.mock_testing_async_failure_hook == 0
assert sl.prometheusServicesLogger.mock_testing_success_calls > 0 assert sl.prometheusServicesLogger.mock_testing_success_calls > 0
except Exception as e: except Exception as e:
@ -194,6 +200,8 @@ async def test_router_with_caching_bad_call():
pass pass
assert sl.mock_testing_async_failure_hook > 0 assert sl.mock_testing_async_failure_hook > 0
assert sl.mock_testing_async_success_hook == 0
assert sl.mock_testing_sync_success_hook == 0
except Exception as e: except Exception as e:
pytest.fail(f"An exception occured - {str(e)}") pytest.fail(f"An exception occured - {str(e)}")