fix(huggingface_restapi.py): async implementation

This commit is contained in:
Krrish Dholakia 2023-11-15 16:54:08 -08:00
parent cc955fca89
commit 03efc9185e
2 changed files with 31 additions and 34 deletions

View file

@@ -448,11 +448,9 @@ class Huggingface(BaseLLM):
input_text: str,
model: str,
optional_params: dict):
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
response = None
try:
async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers)
response_json = response.json()
if response.status_code != 200:
@@ -481,19 +479,18 @@ class Huggingface(BaseLLM):
headers: dict,
model_response: ModelResponse,
model: str):
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
async with client.stream(
async with httpx.AsyncClient() as client:
response = client.stream(
"POST",
url=f"{api_base}",
json=data,
headers=headers,
method="POST"
) as response:
if response.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
headers=headers
)
async with response as r:
if r.status_code != 200:
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk

View file

@@ -4,7 +4,9 @@
import sys, os
import pytest
import traceback
import asyncio
import asyncio, logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
sys.path.insert(
0, os.path.abspath("../..")
@@ -99,5 +101,3 @@ def test_get_response_non_openai_streaming():
pytest.fail(f"An exception occurred: {e}")
return response
asyncio.run(test_async_call())
test_get_response_non_openai_streaming()