Merge pull request #1989 from BerriAI/litellm_redis_url_fix
fix(redis.py): fix instantiating redis client from url
Commit 1e238614c8
5 changed files with 44 additions and 17 deletions
@@ -98,6 +98,9 @@ def _get_redis_client_logic(**env_overrides):
 
 
 def get_redis_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        redis_kwargs.pop("connection_pool", None)  # redis.from_url doesn't support setting your own connection pool
+        return redis.Redis.from_url(**redis_kwargs)
     return redis.Redis(**redis_kwargs)
 

@@ -105,6 +108,9 @@ def get_redis_client(**env_overrides):
 
 
 def get_redis_async_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
+    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        redis_kwargs.pop("connection_pool", None)  # redis.from_url doesn't support setting your own connection pool
+        return async_redis.Redis.from_url(**redis_kwargs)
     return async_redis.Redis(
         socket_timeout=5,
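
The two hunks above are the actual fix named in the commit message: when a Redis URL is configured, the client is built with from_url(), and any "connection_pool" entry is dropped first because from_url() constructs its own pool from the URL. A minimal standalone sketch of that pattern, not litellm's code (the async hunk mirrors it with redis.asyncio; kwargs and URLs below are illustrative):

# Sketch only: build a client from a URL when one is given, otherwise from
# discrete host/port kwargs. from_url() derives its own ConnectionPool from
# the URL, so an explicit "connection_pool" kwarg is popped before calling it.
import redis


def build_redis_client(**redis_kwargs):
    if redis_kwargs.get("url") is not None:
        redis_kwargs.pop("connection_pool", None)
        return redis.Redis.from_url(**redis_kwargs)
    return redis.Redis(**redis_kwargs)


# e.g. build_redis_client(url="redis://localhost:6379/0")
# or   build_redis_client(host="localhost", port=6379, password=None)
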
@@ -124,7 +124,7 @@ class RedisCache(BaseCache):
             self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+            print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)
 
     async def async_set_cache(self, key, value, **kwargs):
         _redis_client = self.init_async_client()

@@ -134,10 +134,12 @@ class RedisCache(BaseCache):
             f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
         )
         try:
-            await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
+            await redis_client.set(
+                name=key, value=json.dumps(value), ex=ttl, get=True
+            )
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+            print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)
 
     async def async_set_cache_pipeline(self, cache_list, ttl=None):
         """
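
Two small behavioural tweaks to RedisCache: set() failures are surfaced through print_verbose instead of logging.debug, and the async set call now passes get=True. Assuming redis-py 4.x semantics, get=True maps to "SET key value GET" (Redis >= 6.2) and makes the call return the previous value rather than a plain OK. A hedged sketch of the async path (client construction here is illustrative, not litellm's config handling):

# Sketch of the async set path, not the litellm implementation.
import json
from typing import Optional

import redis.asyncio as async_redis


async def async_set_cache(key: str, value, ttl: Optional[int] = None):
    client = async_redis.Redis()  # illustrative; litellm builds this from its config
    try:
        # get=True -> "SET key value EX ttl GET": returns the old value (or None)
        old = await client.set(name=key, value=json.dumps(value), ex=ttl, get=True)
        return old
    except Exception as e:
        # non-blocking: a cache failure is reported, never raised to the caller
        print(f"LiteLLM Caching: set() - Got exception from REDIS : {e}")
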
@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": optional_params,
-                    "stream": True
-                    if "stream" in optional_params and optional_params["stream"] == True
-                    else False,
+                    "stream": (
+                        True
+                        if "stream" in optional_params
+                        and optional_params["stream"] == True
+                        else False
+                    ),
                 }
                 input_text = prompt
             else:

@@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": inference_params,
-                    "stream": True
-                    if "stream" in optional_params and optional_params["stream"] == True
-                    else False,
+                    "stream": (
+                        True
+                        if "stream" in optional_params
+                        and optional_params["stream"] == True
+                        else False
+                    ),
                 }
                 input_text = prompt
                 ## LOGGING
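
The two data = {...} hunks above are formatting-only: the conditional expression for the "stream" flag is wrapped in parentheses (black style) but computes the same value. For readers, the whole expression boils down to the following sketch, where optional_params stands in for either params dict:

# Equivalent spelling of the re-wrapped conditional (same "== True" semantics).
optional_params = {"stream": True}  # illustrative input
stream_flag = optional_params.get("stream", False) == True  # noqa: E712
data = {"inputs": "some prompt", "parameters": optional_params, "stream": stream_flag}
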
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
-        timeout: float
+        timeout: float,
     ):
         response = None
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
-                response = await client.post(
-                    url=api_base, json=data, headers=headers
-                )
+                response = await client.post(url=api_base, json=data, headers=headers)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise HuggingfaceError(

@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
-        timeout: float
+        timeout: float,
     ):
         async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
@@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
             )
             async with response as r:
                 if r.status_code != 200:
+                    text = await r.aread()
                     raise HuggingfaceError(
                         status_code=r.status_code,
-                        message="An error occurred while streaming",
+                        message=str(text),
                     )
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
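
In the streaming path above, a failed response now reports the server's actual error body (await r.aread()) instead of a generic message. A self-contained sketch of that pattern with httpx; the endpoint, payload, and error type below are illustrative:

# Sketch: read the error body of a streamed httpx response before raising.
import httpx


async def stream_lines(api_base: str, data: dict, headers: dict, timeout: float):
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream("POST", api_base, json=data, headers=headers) as r:
            if r.status_code != 200:
                text = await r.aread()  # body not yet consumed for streamed responses
                raise RuntimeError(f"HF request failed ({r.status_code}): {text!r}")
            async for line in r.aiter_lines():
                yield line
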
@@ -625,8 +630,12 @@ class Huggingface(BaseLLM):
                     custom_llm_provider="huggingface",
                     logging_obj=logging_obj,
                 )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+
+                async def generator():
+                    async for transformed_chunk in streamwrapper:
+                        yield transformed_chunk
+
+                return generator()
 
     def embedding(
         self,
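
The last Huggingface hunk moves the yields into a nested async def generator() and returns generator(). The point, in standard Python terms (this is not litellm code): an async def whose body contains yield is itself an async generator and cannot be awaited, so wrapping the yields keeps the outer function a plain coroutine whose awaited result is the stream.

# Sketch of the calling convention the refactor enables.
import asyncio


async def fake_completion_stream():
    chunks = ("a", "b", "c")  # stand-in for the CustomStreamWrapper above

    async def generator():
        for chunk in chunks:
            yield chunk

    return generator()  # outer function stays awaitable


async def main():
    stream = await fake_completion_stream()  # await first...
    async for chunk in stream:               # ...then iterate
        print(chunk)


asyncio.run(main())
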
@@ -1031,6 +1031,8 @@ async def test_hf_completion_tgi_stream():
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
+    except litellm.ServiceUnavailableError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
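
The streaming test now treats litellm.ServiceUnavailableError as an acceptable outcome (the HF endpoint being down is not a test failure), while any other exception still fails the test. The shape of that handling, sketched with the request itself elided:

# Sketch of the test's exception handling, not the full test body.
import pytest
import litellm


async def run_tgi_stream_check(make_request):
    try:
        complete_response = await make_request()
        if complete_response.strip() == "":
            raise Exception("Empty response received")
    except litellm.ServiceUnavailableError:
        pass  # HF endpoint unavailable (503): tolerated, not a failure
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
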
@@ -7029,6 +7029,14 @@ def exception_type(
                         model=model,
                         response=original_exception.response,
                     )
+                elif original_exception.status_code == 503:
+                    exception_mapping_worked = True
+                    raise ServiceUnavailableError(
+                        message=f"HuggingfaceException - {original_exception.message}",
+                        llm_provider="huggingface",
+                        model=model,
+                        response=original_exception.response,
+                    )
                 else:
                     exception_mapping_worked = True
                     raise APIError(
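
The exception_type hunk is what produces that ServiceUnavailableError in the first place: a Huggingface response with status 503 is mapped to ServiceUnavailableError instead of falling through to the generic APIError branch. A reduced sketch of the mapping (kwargs mirror the diff above; the surrounding provider dispatch is elided and the helper name is hypothetical):

# Sketch of the new 503 branch for the "huggingface" provider.
import litellm


def map_huggingface_exception(original_exception, model: str):
    if getattr(original_exception, "status_code", None) == 503:
        raise litellm.ServiceUnavailableError(
            message=f"HuggingfaceException - {original_exception.message}",
            llm_provider="huggingface",
            model=model,
            response=original_exception.response,
        )
    raise original_exception  # other statuses keep their existing mapping
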