diff --git a/litellm/_redis.py b/litellm/_redis.py
index 4484926d4..69ff6f3f2 100644
--- a/litellm/_redis.py
+++ b/litellm/_redis.py
@@ -98,6 +98,9 @@ def _get_redis_client_logic(**env_overrides):
 def get_redis_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        redis_kwargs.pop(
+            "connection_pool", None
+        )  # redis.from_url doesn't support setting your own connection pool
         return redis.Redis.from_url(**redis_kwargs)
     return redis.Redis(**redis_kwargs)
 
@@ -105,6 +108,9 @@ def get_redis_client(**env_overrides):
 def get_redis_async_client(**env_overrides):
     redis_kwargs = _get_redis_client_logic(**env_overrides)
     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
+        redis_kwargs.pop(
+            "connection_pool", None
+        )  # redis.from_url doesn't support setting your own connection pool
         return async_redis.Redis.from_url(**redis_kwargs)
     return async_redis.Redis(
         socket_timeout=5,
diff --git a/litellm/caching.py b/litellm/caching.py
index 564972068..567b9aadb 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -124,7 +124,7 @@ class RedisCache(BaseCache):
             self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+            print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)
 
     async def async_set_cache(self, key, value, **kwargs):
         _redis_client = self.init_async_client()
@@ -134,10 +134,12 @@ class RedisCache(BaseCache):
             f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
         )
         try:
-            await redis_client.set(name=key, value=json.dumps(value), ex=ttl)
+            await redis_client.set(
+                name=key, value=json.dumps(value), ex=ttl, get=True
+            )
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            logging.debug("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+            print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)
 
     async def async_set_cache_pipeline(self, cache_list, ttl=None):
         """
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index eb8ce38b9..e66627ccc 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -399,9 +399,12 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": optional_params,
-                    "stream": True
-                    if "stream" in optional_params and optional_params["stream"] == True
-                    else False,
+                    "stream": (
+                        True
+                        if "stream" in optional_params
+                        and optional_params["stream"] == True
+                        else False
+                    ),
                 }
                 input_text = prompt
             else:
@@ -430,9 +433,12 @@ class Huggingface(BaseLLM):
                 data = {
                     "inputs": prompt,
                     "parameters": inference_params,
-                    "stream": True
-                    if "stream" in optional_params and optional_params["stream"] == True
-                    else False,
+                    "stream": (
+                        True
+                        if "stream" in optional_params
+                        and optional_params["stream"] == True
+                        else False
+                    ),
                 }
                 input_text = prompt
                 ## LOGGING
@@ -561,14 +567,12 @@ class Huggingface(BaseLLM):
         input_text: str,
         model: str,
         optional_params: dict,
-        timeout: float
+        timeout: float,
     ):
         response = None
         try:
             async with httpx.AsyncClient(timeout=timeout) as client:
-                response = await client.post(
-                    url=api_base, json=data, headers=headers
-                )
+                response = await client.post(url=api_base, json=data, headers=headers)
                 response_json = response.json()
                 if response.status_code != 200:
                     raise HuggingfaceError(
@@ -607,7 +611,7 @@ class Huggingface(BaseLLM):
         headers: dict,
         model_response: ModelResponse,
         model: str,
-        timeout: float
+        timeout: float,
     ):
         async with httpx.AsyncClient(timeout=timeout) as client:
             response = client.stream(
@@ -615,9 +619,10 @@ class Huggingface(BaseLLM):
             )
             async with response as r:
                 if r.status_code != 200:
+                    text = await r.aread()
                     raise HuggingfaceError(
                         status_code=r.status_code,
-                        message="An error occurred while streaming",
+                        message=str(text),
                     )
                 streamwrapper = CustomStreamWrapper(
                     completion_stream=r.aiter_lines(),
@@ -625,8 +630,12 @@ class Huggingface(BaseLLM):
                     custom_llm_provider="huggingface",
                     logging_obj=logging_obj,
                 )
-                async for transformed_chunk in streamwrapper:
-                    yield transformed_chunk
+
+                async def generator():
+                    async for transformed_chunk in streamwrapper:
+                        yield transformed_chunk
+
+                return generator()
 
     def embedding(
         self,
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index a5497b539..30d777d79 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1031,6 +1031,8 @@ async def test_hf_completion_tgi_stream():
         if complete_response.strip() == "":
            raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
+    except litellm.ServiceUnavailableError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
diff --git a/litellm/utils.py b/litellm/utils.py
index d3efcea73..01a7b37b5 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7029,6 +7029,14 @@ def exception_type(
                             model=model,
                             response=original_exception.response,
                         )
+                    elif original_exception.status_code == 503:
+                        exception_mapping_worked = True
+                        raise ServiceUnavailableError(
+                            message=f"HuggingfaceException - {original_exception.message}",
+                            llm_provider="huggingface",
+                            model=model,
+                            response=original_exception.response,
+                        )
                     else:
                         exception_mapping_worked = True
                         raise APIError(