fix(huggingface_restapi.py): async implementation

This commit is contained in:
Krrish Dholakia 2023-11-15 16:54:08 -08:00
parent cc955fca89
commit 03efc9185e
2 changed files with 31 additions and 34 deletions

View file

@ -448,24 +448,22 @@ class Huggingface(BaseLLM):
input_text: str,
model: str,
optional_params: dict):
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
response = None
try:
response = await client.post(url=api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json,
model_response=model_response,
task=task,
encoding=encoding,
input_text=input_text,
model=model,
optional_params=optional_params)
async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
## RESPONSE OBJECT
return self.convert_to_model_response_object(completion_response=response_json,
model_response=model_response,
task=task,
encoding=encoding,
input_text=input_text,
model=model,
optional_params=optional_params)
except Exception as e:
if isinstance(e,httpx.TimeoutException):
raise HuggingfaceError(status_code=500, message="Request Timeout Error")
@ -481,21 +479,20 @@ class Huggingface(BaseLLM):
headers: dict,
model_response: ModelResponse,
model: str):
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
async with client.stream(
url=f"{api_base}",
json=data,
headers=headers,
method="POST"
) as response:
if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
async with httpx.AsyncClient() as client:
response = client.stream(
"POST",
url=f"{api_base}",
json=data,
headers=headers
)
async with response as r:
if r.status_code != 200:
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
def embedding(self,
model: str,