forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): async implementation
This commit is contained in:
parent cc955fca89
commit 03efc9185e
2 changed files with 31 additions and 34 deletions
@@ -448,24 +448,22 @@ class Huggingface(BaseLLM):
                           input_text: str,
                           model: str,
                           optional_params: dict):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
         response = None
         try:
-            response = await client.post(url=api_base, json=data, headers=headers)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-            ## RESPONSE OBJECT
-            return self.convert_to_model_response_object(completion_response=response_json,
-                                                         model_response=model_response,
-                                                         task=task,
-                                                         encoding=encoding,
-                                                         input_text=input_text,
-                                                         model=model,
-                                                         optional_params=optional_params)
+            async with httpx.AsyncClient() as client:
+                response = await client.post(url=api_base, json=data, headers=headers)
+                response_json = response.json()
+                if response.status_code != 200:
+                    raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+                ## RESPONSE OBJECT
+                return self.convert_to_model_response_object(completion_response=response_json,
+                                                             model_response=model_response,
+                                                             task=task,
+                                                             encoding=encoding,
+                                                             input_text=input_text,
+                                                             model=model,
+                                                             optional_params=optional_params)
 
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise HuggingfaceError(status_code=500, message="Request Timeout Error")
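For readers skimming the diff: the new non-streaming path opens a short-lived httpx.AsyncClient per request instead of reusing the cached self._aclient_session. A minimal sketch of that pattern, with hypothetical names (post_completion, url, payload) standing in for the real api_base/data/headers plumbing in the module:

    import asyncio
    import httpx

    async def post_completion(url: str, payload: dict, headers: dict) -> dict:
        # One short-lived client per call; the async context manager closes
        # the connection pool even if the request raises.
        async with httpx.AsyncClient() as client:
            response = await client.post(url=url, json=payload, headers=headers)
            if response.status_code != 200:
                raise RuntimeError(f"endpoint returned {response.status_code}: {response.text}")
            return response.json()

    # Example (needs a real endpoint):
    # result = asyncio.run(post_completion("https://example.com/models/some-model",
    #                                      {"inputs": "Hello"}, {}))

Because the client is closed by the context manager on every exit path, the cached-session bookkeeping (create_aclient_session and the _aclient_session check) could presumably be dropped.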
@@ -481,21 +479,20 @@ class Huggingface(BaseLLM):
                              headers: dict,
                              model_response: ModelResponse,
                              model: str):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST"
-        ) as response:
-            if response.status_code != 200:
-                raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
-            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        async with httpx.AsyncClient() as client:
+            response = client.stream(
+                "POST",
+                url=f"{api_base}",
+                json=data,
+                headers=headers
+            )
+            async with response as r:
+                if r.status_code != 200:
+                    raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
+                streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
+                async for transformed_chunk in streamwrapper:
+                    yield transformed_chunk
 
     def embedding(self,
                   model: str,
 
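The streaming path follows the same idea via httpx's stream API: client.stream("POST", ...) returns an async context manager whose response exposes aiter_lines() for incremental reading; binding it to a variable first and then entering it with async with response as r (as the diff does) is equivalent to entering it directly. A rough sketch, with an illustrative stream_hf_lines helper that is not part of the actual module:

    import httpx

    async def stream_hf_lines(url: str, payload: dict, headers: dict):
        # Illustrative helper: yield the raw lines of a streaming response.
        async with httpx.AsyncClient() as client:
            # The body is only consumed as we iterate, so the full response
            # never has to sit in memory.
            async with client.stream("POST", url, json=payload, headers=headers) as response:
                if response.status_code != 200:
                    raise RuntimeError(f"streaming request failed with {response.status_code}")
                async for line in response.aiter_lines():
                    yield line

In the diff those raw lines are then fed to CustomStreamWrapper, which handles the provider-specific chunk parsing.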
@@ -4,7 +4,9 @@
 import sys, os
 import pytest
 import traceback
-import asyncio
+import asyncio, logging
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -99,5 +101,3 @@ def test_get_response_non_openai_streaming():
             pytest.fail(f"An exception occurred: {e}")
         return response
     asyncio.run(test_async_call())
-
-test_get_response_non_openai_streaming()
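On the test side, the synchronous test drives its async helper with asyncio.run, so removing the module-level test_get_response_non_openai_streaming() call simply leaves invocation to pytest's collector. A hedged sketch of that wrapper pattern (function and return values here are illustrative, not the real test's):

    import asyncio
    import pytest

    def test_streaming_call_completes():
        # Illustrative stand-in for the real async litellm call the test exercises.
        async def run_call():
            await asyncio.sleep(0)  # pretend network work
            return {"choices": []}

        try:
            response = asyncio.run(run_call())
        except Exception as e:
            pytest.fail(f"An exception occurred: {e}")
        assert isinstance(response, dict)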