fix(huggingface_restapi.py): fixes issue where 'wait_for_model' was not being passed as expected

This commit is contained in:
Krrish Dholakia 2024-08-09 08:35:36 -07:00
parent a1c3167853
commit 51ccfa9e77
3 changed files with 64 additions and 3 deletions

View file

@@ -838,7 +838,12 @@ class Huggingface(BaseLLM):
return {"inputs": input} # default to feature-extraction pipeline tag return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input( async def _async_transform_input(
self, model: str, task_type: Optional[str], embed_url: str, input: List self,
model: str,
task_type: Optional[str],
embed_url: str,
input: List,
optional_params: dict,
) -> dict: ) -> dict:
hf_task = await async_get_hf_task_embedding_for_model( hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url model=model, task_type=task_type, api_base=embed_url
@@ -846,6 +851,9 @@ class Huggingface(BaseLLM):
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task) data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data return data
def _transform_input( def _transform_input(
@@ -856,6 +864,7 @@ class Huggingface(BaseLLM):
optional_params: dict, optional_params: dict,
embed_url: str, embed_url: str,
) -> dict: ) -> dict:
data: Dict = {}
## TRANSFORMATION ## ## TRANSFORMATION ##
if "sentence-transformers" in model: if "sentence-transformers" in model:
if len(input) == 0: if len(input) == 0:
@@ -865,7 +874,7 @@ class Huggingface(BaseLLM):
) )
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}} data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else: else:
data = {"inputs": input} # type: ignore data = {"inputs": input}
task_type = optional_params.pop("input_type", None) task_type = optional_params.pop("input_type", None)
@@ -882,6 +891,9 @@ class Huggingface(BaseLLM):
input=input, pipeline_tag=hf_task input=input, pipeline_tag=hf_task
) )
if len(optional_params.keys()) > 0:
data["options"] = optional_params
return data return data
def _process_embedding_response( def _process_embedding_response(

View file

@@ -1,3 +1,4 @@
import json
import os import os
import sys import sys
import traceback import traceback
@@ -11,7 +12,7 @@ load_dotenv()
sys.path.insert( sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import litellm import litellm
from litellm import completion, completion_cost, embedding from litellm import completion, completion_cost, embedding
@@ -740,3 +741,43 @@ async def test_databricks_embeddings(sync_mode):
# print(response) # print(response)
# local_proxy_embeddings() # local_proxy_embeddings()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_hf_embedddings_with_optional_params(sync_mode):
    """Provider-specific kwargs (here ``wait_for_model=True``) must reach the
    HuggingFace embedding request body under an ``options`` key.

    The HTTP transport's ``post`` is replaced with a mock, so the call is
    expected to blow up after the request is captured — hence the swallowed
    exception; the assertions inspect what was sent, not the response.
    """
    litellm.set_verbose = True
    if sync_mode:
        client, post_mock = HTTPHandler(concurrent_limit=1), MagicMock()
    else:
        client, post_mock = AsyncHTTPHandler(concurrent_limit=1), AsyncMock()
    # Same request in both modes; only the entry point differs.
    request_kwargs = dict(
        model="huggingface/jinaai/jina-embeddings-v2-small-en",
        input=["good morning from litellm"],
        wait_for_model=True,
        client=client,
    )
    with patch.object(client, "post", new=post_mock) as mock_client:
        try:
            if sync_mode:
                response = embedding(**request_kwargs)
            else:
                response = await litellm.aembedding(**request_kwargs)
        except Exception:
            # Mocked transport -> downstream response parsing fails; ignored on purpose.
            pass
        mock_client.assert_called_once()
        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
        sent_payload = mock_client.call_args.kwargs["data"]
        assert "options" in sent_payload
        json_data = json.loads(sent_payload)
        assert "wait_for_model" in json_data["options"]
        assert json_data["options"]["wait_for_model"] is True

View file

@@ -438,3 +438,11 @@ def test_get_optional_params_image_gen():
print(response) print(response)
assert "aws_region_name" in response assert "aws_region_name" in response
def test_bedrock_optional_params_embeddings_provider_specific_params():
    """Provider-specific embedding kwargs must be retained, not dropped.

    ``wait_for_model`` is a HuggingFace passthrough parameter; after filtering
    by ``get_optional_params_embeddings`` it should be the one and only entry.

    NOTE(review): the test name says 'bedrock' but the provider under test is
    'huggingface' — likely copied from a neighbouring test; confirm before renaming.
    """
    optional_params = get_optional_params_embeddings(
        custom_llm_provider="huggingface",
        wait_for_model=True,
    )
    assert len(optional_params) == 1
    # Pin the key and value too: a bare length check would pass even if the
    # single surviving entry were some unrelated parameter.
    assert optional_params.get("wait_for_model") is True