Merge pull request #1989 from BerriAI/litellm_redis_url_fix

fix(redis.py): fix instantiating redis client from url
Krish Dholakia 2024-02-15 21:23:17 -08:00 committed by GitHub
commit 1e238614c8
5 changed files with 44 additions and 17 deletions

View file

@@ -98,6 +98,9 @@ def _get_redis_client_logic(**env_overrides):
def get_redis_client(**env_overrides):
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        redis_kwargs.pop(
            "connection_pool", None
        )  # redis.from_url doesn't support setting your own connection pool
        return redis.Redis.from_url(**redis_kwargs)
    return redis.Redis(**redis_kwargs)

@@ -105,6 +108,9 @@ def get_redis_client(**env_overrides):
def get_redis_async_client(**env_overrides):
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        redis_kwargs.pop(
            "connection_pool", None
        )  # redis.from_url doesn't support setting your own connection pool
        return async_redis.Redis.from_url(**redis_kwargs)
    return async_redis.Redis(
        socket_timeout=5,
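The connection_pool pop above is the core of the fix: when a URL is configured, the client must be built with from_url() and any caller-supplied pool dropped first, per the inline comment in the diff. A minimal sketch of that pattern, using a hypothetical build_redis_client helper rather than litellm's _get_redis_client_logic (the plain-constructor branch and the url pop there are illustrative assumptions):

import redis

def build_redis_client(**kwargs):
    # Hypothetical standalone helper mirroring the pattern in the hunk above.
    # When a URL is supplied, drop any caller-provided connection_pool first:
    # redis.Redis.from_url() builds its own pool from the URL (see the comment
    # in the diff about from_url not supporting a custom pool).
    if kwargs.get("url") is not None:
        kwargs.pop("connection_pool", None)
        return redis.Redis.from_url(**kwargs)
    kwargs.pop("url", None)  # plain redis.Redis() does not take a url kwarg
    return redis.Redis(**kwargs)

client = build_redis_client(url="redis://localhost:6379/0", connection_pool=None)
client.ping()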

View file

@@ -124,7 +124,7 @@ class RedisCache(BaseCache):
            self.redis_client.set(name=key, value=str(value), ex=ttl)
        except Exception as e:
            # NON blocking - notify users Redis is throwing an exception
            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
            print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)

    async def async_set_cache(self, key, value, **kwargs):
        _redis_client = self.init_async_client()

@@ -134,10 +134,12 @@ class RedisCache(BaseCache):
                f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
            )
            try:
                await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
                await redis_client.set(
                    name=key, value=json.dumps(value), ex=ttl, get=True
                )
            except Exception as e:
                # NON blocking - notify users Redis is throwing an exception
                logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
                print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)

    async def async_set_cache_pipeline(self, cache_list, ttl=None):
        """

View file

@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
                data = {
                    "inputs": prompt,
                    "parameters": optional_params,
                    "stream": True
                    if "stream" in optional_params and optional_params["stream"] == True
                    else False,
                    "stream": (
                        True
                        if "stream" in optional_params
                        and optional_params["stream"] == True
                        else False
                    ),
                }
                input_text = prompt
            else:
@@ -430,9 +433,12 @@
                data = {
                    "inputs": prompt,
                    "parameters": inference_params,
                    "stream": True
                    if "stream" in optional_params and optional_params["stream"] == True
                    else False,
                    "stream": (
                        True
                        if "stream" in optional_params
                        and optional_params["stream"] == True
                        else False
                    ),
                }
                input_text = prompt
            ## LOGGING
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
        input_text: str,
        model: str,
        optional_params: dict,
        timeout: float
        timeout: float,
    ):
        response = None
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(
                    url=api_base, json=data, headers=headers
                )
                response = await client.post(url=api_base, json=data, headers=headers)
                response_json = response.json()
                if response.status_code != 200:
                    raise HuggingfaceError(
@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float
        timeout: float,
    ):
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = client.stream(
@@ -615,9 +619,10 @@
            )
            async with response as r:
                if r.status_code != 200:
                    text = await r.aread()
                    raise HuggingfaceError(
                        status_code=r.status_code,
                        message="An error occurred while streaming",
                        message=str(text),
                    )
                streamwrapper = CustomStreamWrapper(
                    completion_stream=r.aiter_lines(),
@@ -625,8 +630,12 @@
                    custom_llm_provider="huggingface",
                    logging_obj=logging_obj,
                )

                async for transformed_chunk in streamwrapper:
                    yield transformed_chunk

                async def generator():
                    async for transformed_chunk in streamwrapper:
                        yield transformed_chunk

                return generator()

    def embedding(
        self,
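The last hunk above replaces a bare async for ... yield with an inner generator() that is returned. The general Python difference between the two shapes: a yield in the body makes the outer function an async-generator function whose body only runs once iteration starts, while returning an inner generator makes the outer function a plain coroutine whose setup code runs as soon as it is awaited. A toy, self-contained sketch of that distinction (not litellm's CustomStreamWrapper or HTTP code):

import asyncio

async def numbers():
    # Toy stand-in for an async token stream.
    for i in range(3):
        yield i

# Shape A: `yield` in the body -> async generator function; calling it runs
# none of the body until the first iteration.
async def stream_a():
    print("A: setup runs lazily, on first iteration")
    async for chunk in numbers():
        yield chunk

# Shape B (the shape this diff moves to): a plain coroutine whose body runs
# when awaited, handing back an inner async generator to iterate.
async def stream_b():
    print("B: setup runs when the coroutine is awaited")

    async def generator():
        async for chunk in numbers():
            yield chunk

    return generator()

async def main():
    gen_a = stream_a()        # nothing printed yet
    gen_b = await stream_b()  # "B: setup ..." prints here
    print([c async for c in gen_a])  # "A: setup ..." prints during iteration
    print([c async for c in gen_b])

asyncio.run(main())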

View file

@@ -1031,6 +1031,8 @@ async def test_hf_completion_tgi_stream():
        if complete_response.strip() == "":
            raise Exception("Empty response received")
        print(f"completion_response: {complete_response}")
    except litellm.ServiceUnavailableError as e:
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
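This test change pairs with the new 503 mapping in utils.py below: a transient Hugging Face outage now surfaces as litellm.ServiceUnavailableError, which the test swallows instead of failing. A hedged sketch of the same pattern in isolation (a hypothetical test, not the one in the diff; assumes pytest-asyncio and an arbitrarily chosen Hugging Face model):

import pytest
import litellm

@pytest.mark.asyncio  # assumes pytest-asyncio is installed
async def test_hf_stream_tolerates_outage():
    try:
        response = await litellm.acompletion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",  # hypothetical model choice
            messages=[{"role": "user", "content": "hi"}],
            stream=True,
        )
        chunks = [chunk async for chunk in response]
        assert len(chunks) > 0
    except litellm.ServiceUnavailableError:
        # A 503 from the hosted endpoint is an outage, not a test failure.
        pass
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")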

View file

@@ -7029,6 +7029,14 @@ def exception_type(
                    model=model,
                    response=original_exception.response,
                )
            elif original_exception.status_code == 503:
                exception_mapping_worked = True
                raise ServiceUnavailableError(
                    message=f"HuggingfaceException - {original_exception.message}",
                    llm_provider="huggingface",
                    model=model,
                    response=original_exception.response,
                )
            else:
                exception_mapping_worked = True
                raise APIError(
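The new 503 branch follows the same status-code-to-exception dispatch that exception_type applies to other Hugging Face errors. A compressed, hypothetical sketch of that dispatch shape (not litellm's actual function; litellm's real exception classes also carry llm_provider, model, and the original response):

# Hypothetical, trimmed-down version of the elif chain above.
class ProviderError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        super().__init__(message)

class RateLimitError(ProviderError):
    """429s."""

class ServiceUnavailableError(ProviderError):
    """503s - the case added by this commit."""

class APIError(ProviderError):
    """Fallback for unmapped status codes."""

STATUS_TO_EXCEPTION = {429: RateLimitError, 503: ServiceUnavailableError}

def map_huggingface_error(status_code, message):
    exception_class = STATUS_TO_EXCEPTION.get(status_code, APIError)
    raise exception_class(status_code, f"HuggingfaceException - {message}")

# Usage:
# map_huggingface_error(503, "Model is overloaded")  # -> ServiceUnavailableError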