fix(huggingface_restapi.py): async implementation

This commit is contained in:
Krrish Dholakia 2023-11-15 16:54:08 -08:00
parent cc955fca89
commit 03efc9185e
2 changed files with 31 additions and 34 deletions

View file

@@ -448,11 +448,9 @@ class Huggingface(BaseLLM):
input_text: str, input_text: str,
model: str, model: str,
optional_params: dict): optional_params: dict):
if self._aclient_session is None:
self._aclient_session = self.create_aclient_session()
client = self._aclient_session
response = None response = None
try: try:
async with httpx.AsyncClient() as client:
response = await client.post(url=api_base, json=data, headers=headers) response = await client.post(url=api_base, json=data, headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
@@ -481,19 +479,18 @@ class Huggingface(BaseLLM):
headers: dict, headers: dict,
model_response: ModelResponse, model_response: ModelResponse,
model: str): model: str):
if self._aclient_session is None: async with httpx.AsyncClient() as client:
self._aclient_session = self.create_aclient_session() response = client.stream(
client = self._aclient_session "POST",
async with client.stream(
url=f"{api_base}", url=f"{api_base}",
json=data, json=data,
headers=headers, headers=headers
method="POST" )
) as response: async with response as r:
if response.status_code != 200: if r.status_code != 200:
raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming") raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj) streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper: async for transformed_chunk in streamwrapper:
yield transformed_chunk yield transformed_chunk

View file

@@ -4,7 +4,9 @@
import sys, os import sys, os
import pytest import pytest
import traceback import traceback
import asyncio import asyncio, logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
@@ -99,5 +101,3 @@ def test_get_response_non_openai_streaming():
pytest.fail(f"An exception occurred: {e}") pytest.fail(f"An exception occurred: {e}")
return response return response
asyncio.run(test_async_call()) asyncio.run(test_async_call())
test_get_response_non_openai_streaming()