forked from phoenix/litellm-mirror

fix(huggingface_restapi.py): async implementation

parent cc955fca89
commit 03efc9185e

2 changed files with 31 additions and 34 deletions
@@ -448,24 +448,22 @@ class Huggingface(BaseLLM):
                           input_text: str,
                           model: str,
                           optional_params: dict):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
         response = None
         try:
-            response = await client.post(url=api_base, json=data, headers=headers)
-            response_json = response.json()
-            if response.status_code != 200:
-                raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
-
-            ## RESPONSE OBJECT
-            return self.convert_to_model_response_object(completion_response=response_json,
-                                                         model_response=model_response,
-                                                         task=task,
-                                                         encoding=encoding,
-                                                         input_text=input_text,
-                                                         model=model,
-                                                         optional_params=optional_params)
+            async with httpx.AsyncClient() as client:
+                response = await client.post(url=api_base, json=data, headers=headers)
+                response_json = response.json()
+                if response.status_code != 200:
+                    raise HuggingfaceError(status_code=response.status_code, message=response.text, request=response.request, response=response)
+
+                ## RESPONSE OBJECT
+                return self.convert_to_model_response_object(completion_response=response_json,
+                                                             model_response=model_response,
+                                                             task=task,
+                                                             encoding=encoding,
+                                                             input_text=input_text,
+                                                             model=model,
+                                                             optional_params=optional_params)
         except Exception as e:
             if isinstance(e,httpx.TimeoutException):
                 raise HuggingfaceError(status_code=500, message="Request Timeout Error")
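After this hunk, each completion call opens a short-lived httpx.AsyncClient instead of reusing a session cached on the instance. Below is a minimal standalone sketch of that pattern; hf_completion, the endpoint URL, payload, and headers are hypothetical stand-ins for illustration, not names from the diff:

import asyncio
import httpx

async def hf_completion(api_base: str, data: dict, headers: dict) -> dict:
    # One client per call: the async context manager closes the
    # connection pool even if the request raises.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url=api_base, json=data, headers=headers)
            if response.status_code != 200:
                raise RuntimeError(f"HF error {response.status_code}: {response.text}")
            return response.json()
    except httpx.TimeoutException:
        # Mirrors the diff: timeouts surface as a 500-style error.
        raise RuntimeError("Request Timeout Error")

# Hypothetical model endpoint and payload, for illustration only.
print(asyncio.run(hf_completion(
    "https://api-inference.huggingface.co/models/gpt2",
    {"inputs": "Hello"},
    {"Authorization": "Bearer <token>"},
)))

The trade-off is losing cross-call connection reuse, but a likely motivation is that a client cached on the instance stays bound to the event loop it was created on, whereas a per-call client is always created on the loop running the request.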
@@ -481,21 +479,20 @@ class Huggingface(BaseLLM):
                                  headers: dict,
                                  model_response: ModelResponse,
                                  model: str):
-        if self._aclient_session is None:
-            self._aclient_session = self.create_aclient_session()
-        client = self._aclient_session
-        async with client.stream(
-            url=f"{api_base}",
-            json=data,
-            headers=headers,
-            method="POST"
-        ) as response:
-            if response.status_code != 200:
-                raise HuggingfaceError(status_code=response.status_code, message="An error occurred while streaming")
-
-            streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
-            async for transformed_chunk in streamwrapper:
-                yield transformed_chunk
+        async with httpx.AsyncClient() as client:
+            response = client.stream(
+                "POST",
+                url=f"{api_base}",
+                json=data,
+                headers=headers
+            )
+            async with response as r:
+                if r.status_code != 200:
+                    raise HuggingfaceError(status_code=r.status_code, message="An error occurred while streaming")
+
+                streamwrapper = CustomStreamWrapper(completion_stream=r.aiter_lines(), model=model, custom_llm_provider="huggingface",logging_obj=logging_obj)
+                async for transformed_chunk in streamwrapper:
+                    yield transformed_chunk

     def embedding(self,
                   model: str,
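The streaming path gets the same treatment. A sketch of the httpx async streaming pattern used above, again with a hypothetical endpoint and payload: client.stream("POST", ...) returns an async context manager without sending anything, which is why the diff can assign it to a variable first and enter it separately; the body is only read as it is iterated.

import asyncio
import httpx

async def hf_stream(api_base: str, data: dict, headers: dict):
    async with httpx.AsyncClient() as client:
        # stream() defers reading the body; entering the context
        # manager sends the request and yields the live response.
        async with client.stream("POST", url=api_base, json=data, headers=headers) as r:
            if r.status_code != 200:
                raise RuntimeError("An error occurred while streaming")
            async for line in r.aiter_lines():
                yield line

# Hypothetical endpoint and payload, for illustration only.
async def main():
    async for chunk in hf_stream(
        "https://api-inference.huggingface.co/models/gpt2",
        {"inputs": "Hello", "stream": True},
        {"Authorization": "Bearer <token>"},
    ):
        print(chunk)

asyncio.run(main())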