forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): fixes issue where 'wait_for_model' was not being passed as expected
This commit is contained in:
parent
a1c3167853
commit
51ccfa9e77
3 changed files with 64 additions and 3 deletions
|
@ -838,7 +838,12 @@ class Huggingface(BaseLLM):
|
||||||
return {"inputs": input} # default to feature-extraction pipeline tag
|
return {"inputs": input} # default to feature-extraction pipeline tag
|
||||||
|
|
||||||
async def _async_transform_input(
|
async def _async_transform_input(
|
||||||
self, model: str, task_type: Optional[str], embed_url: str, input: List
|
self,
|
||||||
|
model: str,
|
||||||
|
task_type: Optional[str],
|
||||||
|
embed_url: str,
|
||||||
|
input: List,
|
||||||
|
optional_params: dict,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
hf_task = await async_get_hf_task_embedding_for_model(
|
hf_task = await async_get_hf_task_embedding_for_model(
|
||||||
model=model, task_type=task_type, api_base=embed_url
|
model=model, task_type=task_type, api_base=embed_url
|
||||||
|
@ -846,6 +851,9 @@ class Huggingface(BaseLLM):
|
||||||
|
|
||||||
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
|
||||||
|
|
||||||
|
if len(optional_params.keys()) > 0:
|
||||||
|
data["options"] = optional_params
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _transform_input(
|
def _transform_input(
|
||||||
|
@ -856,6 +864,7 @@ class Huggingface(BaseLLM):
|
||||||
optional_params: dict,
|
optional_params: dict,
|
||||||
embed_url: str,
|
embed_url: str,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
data: Dict = {}
|
||||||
## TRANSFORMATION ##
|
## TRANSFORMATION ##
|
||||||
if "sentence-transformers" in model:
|
if "sentence-transformers" in model:
|
||||||
if len(input) == 0:
|
if len(input) == 0:
|
||||||
|
@ -865,7 +874,7 @@ class Huggingface(BaseLLM):
|
||||||
)
|
)
|
||||||
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
|
||||||
else:
|
else:
|
||||||
data = {"inputs": input} # type: ignore
|
data = {"inputs": input}
|
||||||
|
|
||||||
task_type = optional_params.pop("input_type", None)
|
task_type = optional_params.pop("input_type", None)
|
||||||
|
|
||||||
|
@ -882,6 +891,9 @@ class Huggingface(BaseLLM):
|
||||||
input=input, pipeline_tag=hf_task
|
input=input, pipeline_tag=hf_task
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(optional_params.keys()) > 0:
|
||||||
|
data["options"] = optional_params
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _process_embedding_response(
|
def _process_embedding_response(
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -11,7 +12,7 @@ load_dotenv()
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm import completion, completion_cost, embedding
|
from litellm import completion, completion_cost, embedding
|
||||||
|
@ -740,3 +741,43 @@ async def test_databricks_embeddings(sync_mode):
|
||||||
# print(response)
|
# print(response)
|
||||||
|
|
||||||
# local_proxy_embeddings()
|
# local_proxy_embeddings()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_hf_embedddings_with_optional_params(sync_mode):
    """Check that `wait_for_model` reaches the HTTP request body.

    The HTTP client's `post` method is replaced with a mock, a Hugging Face
    embedding request is issued with `wait_for_model=True`, and the `data`
    keyword passed to `post` is inspected: once decoded as JSON it must carry
    `options.wait_for_model` set to `True`.

    Runs once through the sync `embedding` entry point and once through the
    async `litellm.aembedding` entry point, selected by `sync_mode`.
    """
    litellm.set_verbose = True

    # Pair each transport with the mock flavour matching its `post` method.
    if sync_mode:
        client, mock_obj = HTTPHandler(concurrent_limit=1), MagicMock()
    else:
        client, mock_obj = AsyncHTTPHandler(concurrent_limit=1), AsyncMock()

    # Both entry points are called with exactly the same arguments.
    request_kwargs = {
        "model": "huggingface/jinaai/jina-embeddings-v2-small-en",
        "input": ["good morning from litellm"],
        "wait_for_model": True,
        "client": client,
    }

    with patch.object(client, "post", new=mock_obj) as mock_client:
        try:
            if sync_mode:
                response = embedding(**request_kwargs)
            else:
                response = await litellm.aembedding(**request_kwargs)
        except Exception:
            # Failures after the request is sent are ignored on purpose;
            # presumably the mocked `post` result cannot be parsed as a real
            # response. Only the outgoing payload matters here.
            pass

        mock_client.assert_called_once()

        sent_kwargs = mock_client.call_args.kwargs
        print(f"mock_client.call_args.kwargs: {sent_kwargs}")
        assert "options" in sent_kwargs["data"]
        json_data = json.loads(sent_kwargs["data"])
        assert "wait_for_model" in json_data["options"]
        assert json_data["options"]["wait_for_model"] is True
|
||||||
|
|
|
@ -438,3 +438,11 @@ def test_get_optional_params_image_gen():
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
assert "aws_region_name" in response
|
assert "aws_region_name" in response
|
||||||
|
|
||||||
|
|
||||||
|
def test_bedrock_optional_params_embeddings_provider_specific_params():
    """Check that one provider-specific embedding param yields one entry.

    Calls `get_optional_params_embeddings` with a single extra keyword
    (`wait_for_model=True`) and asserts the result has exactly one item.

    NOTE(review): the test name says "bedrock" but the provider passed is
    "huggingface"; the name is kept so existing test IDs stay stable.
    """
    provider_specific = {"wait_for_model": True}
    resolved = get_optional_params_embeddings(
        custom_llm_provider="huggingface",
        **provider_specific,
    )
    assert len(resolved) == 1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue