mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
fix(huggingface_restapi.py): fixes issue where 'wait_for_model' was not being passed as expected
This commit is contained in:
parent
122d8ab2f4
commit
d382de7b74
3 changed files with 64 additions and 3 deletions
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
@ -11,7 +12,7 @@ load_dotenv()
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import litellm
|
||||
from litellm import completion, completion_cost, embedding
|
||||
|
@ -740,3 +741,43 @@ async def test_databricks_embeddings(sync_mode):
|
|||
# print(response)
|
||||
|
||||
# local_proxy_embeddings()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_hf_embedddings_with_optional_params(sync_mode):
    """Verify that ``wait_for_model`` is forwarded in the HuggingFace request body.

    The HTTP client's ``post`` method is replaced by a mock, so the test only
    inspects the ``data`` keyword argument handed to ``post``: it must be a
    JSON document whose ``options`` object carries ``wait_for_model: true``.

    Args:
        sync_mode: When true, exercise ``embedding`` with an ``HTTPHandler``;
            otherwise exercise ``litellm.aembedding`` with an
            ``AsyncHTTPHandler``.
    """
    litellm.set_verbose = True

    # Choose the handler class and the mock flavour that match the call style.
    handler_cls, mock_cls = (
        (HTTPHandler, MagicMock) if sync_mode else (AsyncHTTPHandler, AsyncMock)
    )
    client = handler_cls(concurrent_limit=1)
    post_mock = mock_cls()

    # Both code paths receive exactly the same arguments.
    call_kwargs = {
        "model": "huggingface/jinaai/jina-embeddings-v2-small-en",
        "input": ["good morning from litellm"],
        "wait_for_model": True,
        "client": client,
    }

    with patch.object(client, "post", new=post_mock) as mock_client:
        try:
            if sync_mode:
                embedding(**call_kwargs)
            else:
                await litellm.aembedding(**call_kwargs)
        except Exception:
            # Errors from the call are ignored on purpose: only the outgoing
            # request is checked below. NOTE(review): presumably the mocked
            # ``post`` result cannot be parsed as a real response - confirm.
            pass

        mock_client.assert_called_once()

        sent_kwargs = mock_client.call_args.kwargs
        print(f"mock_client.call_args.kwargs: {sent_kwargs}")

        raw_body = sent_kwargs["data"]
        assert "options" in raw_body

        json_data = json.loads(raw_body)
        request_options = json_data["options"]
        assert "wait_for_model" in request_options
        assert request_options["wait_for_model"] is True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue