fix(huggingface_restapi.py): async implementation

This commit is contained in:
Krrish Dholakia 2023-11-15 16:54:08 -08:00
parent cc955fca89
commit 03efc9185e
2 changed files with 31 additions and 34 deletions

View file

@ -448,24 +448,22 @@ class Huggingface(BaseLLM):
input_text: str, input_text: str,
model: str, model: str,
optional_params: dict): optional_params: dict):
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
response = None response = None
try: try:
response = await client.post(url=api_base, json=data, headers=headers) async with httpx.AsyncClient() as client:
response_json = response.json() response = await client.post(url=api_base, json=data, headers=headers)
if response.status_code != 200: response_json = response.json()
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response) if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json, ## RESPONSE OBJECT
model_response=model_response, return self.convert_to_model_response_object(completion_response=response_json,
task=task, model_response=model_response,
encoding=encoding, task=task,
input_text=input_text, encoding=encoding,
model=model, input_text=input_text,
optional_params=optional_params) model=model,
optional_params=optional_params)
except Exception as e: except Exception as e:
if isinstance(e,httpx.TimeoutException): if isinstance(e,httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error") raise HuggingfaceError(status_code=500, message="Request Timeout Error")
@ -481,21 +479,20 @@ class Huggingface(BaseLLM):
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str): model: str):
if self._aclient_session is None: async with httpx.AsyncClient() as client:
self._aclient_session = self.create_aclient_session() response = client.stream(
client = self._aclient_session "POST",
async with client.stream( url=f"{api_base}",
url=f"{api_base}", json=data,
json=data, headers=headers
headers=headers, )
method="POST" async with response as r:
) as response: if r.status_code != 200:
if response.status_code != 200: raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj) async for transformed_chunk in streamwrapper:
async for transformed_chunk in streamwrapper: yield transformed_chunk
yield transformed_chunk
def embedding(self, def embedding(self,
model: str, model: str,

View file

@ -4,7 +4,9 @@
import sys, os import sys, os
import pytest import pytest
import traceback import traceback
import asyncio import asyncio, logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@ -99,5 +101,3 @@ def test_get_response_non_openai_streaming():
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
return response return response
asyncio.run(test_async_call()) asyncio.run(test_async_call())
test_get_response_non_openai_streaming()