fix(huggingface_restapi.py): fixes issue where 'wait_for_model' was not being passed as expected

This commit is contained in:
Krrish Dholakia 2024-08-09 08:35:36 -07:00
parent 122d8ab2f4
commit d382de7b74
3 changed files with 64 additions and 3 deletions

View file

@ -1,3 +1,4 @@
import json
import os
import sys
import traceback
@ -11,7 +12,7 @@ load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import litellm
from litellm import completion, completion_cost, embedding
@ -740,3 +741,43 @@ async def test_databricks_embeddings(sync_mode):
# print(response)
# local_proxy_embeddings()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_hf_embedddings_with_optional_params(sync_mode):
litellm.set_verbose = True
if sync_mode:
client = HTTPHandler(concurrent_limit=1)
mock_obj = MagicMock()
else:
client = AsyncHTTPHandler(concurrent_limit=1)
mock_obj = AsyncMock()
with patch.object(client, "post", new=mock_obj) as mock_client:
try:
if sync_mode:
response = embedding(
model="huggingface/jinaai/jina-embeddings-v2-small-en",
input=["good morning from litellm"],
wait_for_model=True,
client=client,
)
else:
response = await litellm.aembedding(
model="huggingface/jinaai/jina-embeddings-v2-small-en",
input=["good morning from litellm"],
wait_for_model=True,
client=client,
)
except Exception:
pass
mock_client.assert_called_once()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
assert "options" in mock_client.call_args.kwargs["data"]
json_data = json.loads(mock_client.call_args.kwargs["data"])
assert "wait_for_model" in json_data["options"]
assert json_data["options"]["wait_for_model"] is True