diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 8b9d906f2..792a42678 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -448,24 +448,22 @@ class Huggingface(BaseLLM):
                           input_text: str,
                           model: str,
                           optional_params: dict):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
         response = None
         try:
-            response = await client.post(url=api_base, json=data, headers=headers)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return self.convert_to_model_response_object(completion_response=response_json,
-                                                         model_response=model_response,
-                                                         task=task,
-                                                         encoding=encoding,
-                                                         input_text=input_text,
-                                                         model=model,
-                                                         optional_params=optional_params)
+            async with httpx.AsyncClient() as client:
+                response = await client.post(url=api_base, json=data, headers=headers)
+                response_json = response.json()
+                if response.status_code != 200:
+                    raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+                ## RESPONSE OBJECT
+                return self.convert_to_model_response_object(completion_response=response_json,
+                                                             model_response=model_response,
+                                                             task=task,
+                                                             encoding=encoding,
+                                                             input_text=input_text,
+                                                             model=model,
+                                                             optional_params=optional_params)
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise HuggingfaceError(status_code=500, message="Request Timeout Error")
@@ -481,21 +479,20 @@ class Huggingface(BaseLLM):
                               headers: dict,
                               model_response: ModelResponse,
                               model: str):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST"
-        ) as response:
-            if response.status_code != 200:
-                raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
-
-            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        async with httpx.AsyncClient() as client:
+            response = client.stream(
+                "POST",
+                url=f"{api_base}",
+                json=data,
+                headers=headers
+            )
+            async with response as r:
+                if r.status_code != 200:
+                    raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
+
+                streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
+                async for transformed_chunk in streamwrapper:
+                    yield transformed_chunk
 
     def embedding(self,
         model: str,
diff --git a/litellm/tests/test_async_fn.py b/litellm/tests/test_async_fn.py
index 6081956f2..920fff3cc 100644
--- a/litellm/tests/test_async_fn.py
+++ b/litellm/tests/test_async_fn.py
@@ -4,7 +4,9 @@
 import sys, os
 import pytest
 import traceback
-import asyncio
+import asyncio, logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
 sys.path.insert(
     0, os.path.abspath("../..")
 )
@@ -99,5 +101,3 @@ def test_get_response_non_openai_streaming():
         pytest.fail(f"An exception occurred: {e}")
         return response
     asyncio.run(test_async_call())
-
-test_get_response_non_openai_streaming()
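
Note: both hunks in huggingface_restapi.py replace the cached self._aclient_session with a per-request httpx.AsyncClient opened via async with, so the connection is always closed when the call finishes. Below is a minimal standalone sketch of that pattern for the streaming path; it is not part of the patch, and names like API_BASE and stream_once are illustrative placeholders.

# Minimal sketch (assumed names, not part of the patch) of the
# per-request httpx.AsyncClient pattern the diff switches to.
import asyncio
import httpx

API_BASE = "https://api-inference.huggingface.co/models/gpt2"  # hypothetical endpoint

async def stream_once(data: dict, headers: dict) -> None:
    # The client lives only for this call, so no session attribute
    # (like the removed self._aclient_session) has to be managed.
    async with httpx.AsyncClient() as client:
        # client.stream(...) returns an async context manager that opens
        # the connection and yields the response body lazily.
        async with client.stream("POST", API_BASE, json=data, headers=headers) as response:
            if response.status_code != 200:
                raise RuntimeError(f"HTTP {response.status_code} while streaming")
            async for line in response.aiter_lines():
                print(line)

# Example invocation (requires a valid token):
# asyncio.run(stream_once({"inputs": "Hello"}, {"Authorization": "Bearer <token>"}))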