mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(huggingface_restapi.py): Support multiple hf embedding types + async hf embeddings
Closes https://github.com/BerriAI/litellm/issues/3261
This commit is contained in:
parent
f1b7d2318c
commit
69afbc6091
3 changed files with 332 additions and 59 deletions
|
@ -409,6 +409,62 @@ def test_hf_embedding():
|
|||
|
||||
# test_hf_embedding()
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def tgi_mock_post(*args, **kwargs):
    """Stand-in for an HTTP client's ``post`` used by the sentence-similarity test.

    Decodes the JSON string passed as the ``data`` keyword argument and asserts
    it equals the fixed sentence-similarity payload below. Positional arguments
    are accepted and ignored.

    Returns:
        A ``MagicMock`` with ``status_code`` 200, a JSON content-type header,
        and a ``json()`` method that returns a single-score list.

    Raises:
        AssertionError: If the decoded ``data`` differs from the fixed payload.
        KeyError: If no ``data`` keyword argument was supplied.
    """
    import json

    wanted_payload = {
        "inputs": {
            "source_sentence": "good morning from litellm",
            "sentences": ["this is another item"],
        }
    }
    sent_payload = json.loads(kwargs["data"])
    assert sent_payload == wanted_payload, "Data does not match the expected data"

    # Keyword arguments to MagicMock become plain attributes on the mock.
    fake_response = MagicMock(
        status_code=200,
        headers={"Content-Type": "application/json"},
    )
    fake_response.json.return_value = [0.7708950042724609]
    return fake_response
|
||||
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_hf_embedding_sentence_sim(sync_mode):
    """Exercise the Hugging Face embedding call through the sync and async entry points.

    The client's ``post`` method is patched with ``tgi_mock_post``, so no real
    network request is made; the mock itself asserts on the request payload it
    receives. Any failure propagates directly to pytest.

    Args:
        sync_mode: When True, call ``embedding`` with an ``HTTPHandler``;
            otherwise await ``litellm.aembedding`` with an ``AsyncHTTPHandler``.
    """
    if sync_mode:
        client = HTTPHandler(concurrent_limit=1)
    else:
        client = AsyncHTTPHandler(concurrent_limit=1)

    with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
        data = {
            "model": "huggingface/TaylorAI/bge-micro-v2",
            "input": ["good morning from litellm", "this is another item"],
            "client": client,
        }
        if sync_mode:
            response = embedding(**data)
        else:
            response = await litellm.aembedding(**data)

        print("response:", response)

        # Both inputs are expected to go out in a single request.
        mock_client.assert_called_once()

        assert isinstance(response.usage, litellm.Usage)
|
||||
|
||||
|
||||
# test async embeddings
|
||||
def test_aembedding():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue