fix(huggingface_restapi.py): fixes issue where 'wait_for_model' was not being passed as expected

This commit is contained in:
Krrish Dholakia 2024-08-09 08:35:36 -07:00
parent a1c3167853
commit 51ccfa9e77
3 changed files with 64 additions and 3 deletions

View file

@@ -838,7 +838,12 @@ class Huggingface(BaseLLM):
return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input(
self, model: str, task_type: Optional[str], embed_url: str, input: List
self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
@@ -846,6 +851,9 @@ class Huggingface(BaseLLM):
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _transform_input(
@@ -856,6 +864,7 @@ class Huggingface(BaseLLM):
optional_params: dict,
embed_url: str,
) -> dict:
data: Dict = {}
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
@@ -865,7 +874,7 @@
)
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input} # type: ignore
data = {"inputs": input}
task_type = optional_params.pop("input_type", None)
@@ -882,6 +891,9 @@ class Huggingface(BaseLLM):
input=input, pipeline_tag=hf_task
)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data
def _process_embedding_response(