forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): async implementation
This commit is contained in:
parent
cc955fca89
commit
03efc9185e
2 changed files with 31 additions and 34 deletions
|
@@ -448,11 +448,9 @@ class Huggingface(BaseLLM):
|
|||
input_text: str,
|
||||
model: str,
|
||||
optional_params: dict):
|
||||
if self._aclient_session is None:
|
||||
self._aclient_session = self.create_aclient_session()
|
||||
client = self._aclient_session
|
||||
response = None
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url=api_base, json=data, headers=headers)
|
||||
response_json = response.json()
|
||||
if response.status_code != 200:
|
||||
|
@@ -481,19 +479,18 @@ class Huggingface(BaseLLM):
|
|||
headers: dict,
|
||||
model_response: ModelResponse,
|
||||
model: str):
|
||||
if self._aclient_session is None:
|
||||
self._aclient_session = self.create_aclient_session()
|
||||
client = self._aclient_session
|
||||
async with client.stream(
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = client.stream(
|
||||
"POST",
|
||||
url=f"{api_base}",
|
||||
json=data,
|
||||
headers=headers,
|
||||
method="POST"
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
|
||||
headers=headers
|
||||
)
|
||||
async with response as r:
|
||||
if r.status_code != 200:
|
||||
raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
|
||||
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
|
|
|
@@ -4,7 +4,9 @@
|
|||
import sys, os
|
||||
import pytest
|
||||
import traceback
|
||||
import asyncio
|
||||
import asyncio, logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
|
@@ -99,5 +101,3 @@ def test_get_response_non_openai_streaming():
|
|||
pytest.fail(f"An exception occurred: {e}")
|
||||
return response
|
||||
asyncio.run(test_async_call())
|
||||
|
||||
test_get_response_non_openai_streaming()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue