feat(huggingface_restapi.py): Support multiple hf embedding types + async hf embeddings

Closes https://github.com/BerriAI/litellm/issues/3261
Krrish Dholakia 2024-07-30 13:32:03 -07:00
parent f1b7d2318c
commit 69afbc6091
3 changed files with 332 additions and 59 deletions
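
What this enables, roughly (the model name is borrowed from the new test at the bottom of this diff; input_type is the per-call task override the handler pops into task_type, so treat this as a sketch rather than documented API):

import asyncio

import litellm

# Sync: input_type picks how the handler shapes the request body.
response = litellm.embedding(
    model="huggingface/TaylorAI/bge-micro-v2",
    input=["good morning from litellm", "this is another item"],
    input_type="sentence-similarity",
)

# Async: the same call routed through the newly-supported litellm.aembedding.
response = asyncio.run(
    litellm.aembedding(
        model="huggingface/TaylorAI/bge-micro-v2",
        input=["good morning from litellm", "this is another item"],
    )
)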


@@ -6,12 +6,13 @@ import os
import time
import types
from enum import Enum
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, get_args
import httpx
import requests
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -60,6 +61,10 @@ hf_tasks = Literal[
"text-generation",
]
hf_tasks_embeddings = Literal[ # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
"sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]
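
For reference, typing.get_args is what lets this Literal double as a runtime whitelist in the new helpers below; a minimal sketch:

from typing import Literal, get_args

hf_tasks_embeddings = Literal[
    "sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
]

assert "rerank" in get_args(hf_tasks_embeddings)  # accepted task_type
assert "summarization" not in get_args(hf_tasks_embeddings)  # helper raises on this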
class HuggingfaceConfig:
"""
@@ -249,6 +254,55 @@ def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
return "text-generation-inference", model # default to tgi
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
def get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = HTTPHandler(concurrent_limit=1)
model_info = http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
async def async_get_hf_task_embedding_for_model(
model: str, task_type: Optional[str], api_base: str
) -> Optional[str]:
if task_type is not None:
if task_type in get_args(hf_tasks_embeddings):
return task_type
else:
raise Exception(
"Invalid task_type={}. Expected one of={}".format(
task_type, hf_tasks_embeddings
)
)
http_client = AsyncHTTPHandler(concurrent_limit=1)
model_info = await http_client.get(url=api_base)
model_info_dict = model_info.json()
pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
return pipeline_tag
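
How these helpers are meant to be called (a sketch; the Hub model-info URL is an assumption here, any endpoint whose JSON body carries a pipeline_tag field would work, and the model name is a placeholder):

import asyncio

# An explicit task_type short-circuits: no request is made, the value is only
# validated against hf_tasks_embeddings and returned as-is.
tag = get_hf_task_embedding_for_model(
    model="BAAI/bge-reranker-base",  # placeholder model
    task_type="rerank",
    api_base="https://huggingface.co/api/models/BAAI/bge-reranker-base",
)
assert tag == "rerank"

# Without a task_type, the helper GETs api_base and reads pipeline_tag from
# the JSON body. The async twin has the same semantics over AsyncHTTPHandler:
tag = asyncio.run(
    async_get_hf_task_embedding_for_model(
        model="BAAI/bge-reranker-base",
        task_type=None,
        api_base="https://huggingface.co/api/models/BAAI/bge-reranker-base",
    )
)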
class Huggingface(BaseLLM):
_client_session: Optional[httpx.Client] = None
_aclient_session: Optional[httpx.AsyncClient] = None
@@ -256,7 +310,7 @@ class Huggingface(BaseLLM):
def __init__(self) -> None:
super().__init__()
def validate_environment(self, api_key, headers):
def validate_environment(self, api_key, headers) -> dict:
default_headers = {
"content-type": "application/json",
}
@@ -762,76 +816,82 @@ class Huggingface(BaseLLM):
async for transformed_chunk in streamwrapper:
yield transformed_chunk
def embedding(
self,
model: str,
input: list,
model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
):
super().embedding()
headers = self.validate_environment(api_key, headers=None)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
def _transform_input_on_pipeline_tag(
self, input: List, pipeline_tag: Optional[str]
) -> dict:
if pipeline_tag is None:
return {"inputs": input}
if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="sentence-similarity requires 2+ sentences",
)
return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
elif pipeline_tag == "rerank":
if len(input) < 2:
raise HuggingfaceError(
status_code=400,
message="reranker requires 2+ sentences",
)
return {"inputs": {"query": input[0], "texts": input[1:]}}
return {"inputs": input} # default to feature-extraction pipeline tag
async def _async_transform_input(
self, model: str, task_type: Optional[str], embed_url: str, input: List
) -> dict:
hf_task = await async_get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
return data
def _transform_input(
self,
input: List,
model: str,
call_type: Literal["sync", "async"],
optional_params: dict,
embed_url: str,
) -> dict:
## TRANSFORMATION ##
if "sentence-transformers" in model:
if len(input) == 0:
raise HuggingfaceError(
status_code=400,
message="sentence transformers requires 2+ sentences",
)
data = {
"inputs": {
"source_sentence": input[0],
"sentences": [
"That is a happy dog",
"That is a very happy person",
"Today is a sunny day",
],
}
}
data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
else:
data = {"inputs": input} # type: ignore
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
task_type = optional_params.pop("input_type", None)
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
if call_type == "sync":
hf_task = get_hf_task_embedding_for_model(
model=model, task_type=task_type, api_base=embed_url
)
elif call_type == "async":
return self._async_transform_input(
model=model, task_type=task_type, embed_url=embed_url, input=input
) # type: ignore
embeddings = response.json()
data = self._transform_input_on_pipeline_tag(
input=input, pipeline_tag=hf_task
)
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
return data
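
One subtlety worth flagging: for models without "sentence-transformers" in the name, call_type="sync" resolves the pipeline tag with a blocking lookup and returns a dict, while call_type="async" returns the un-awaited _async_transform_input coroutine (hence the "# type: ignore" on that return), which the caller must await. A sketch, using an explicit input_type so neither path needs the network; the model name and URL are placeholders:

import asyncio

hf = Huggingface()
kwargs = dict(
    input=["query", "candidate"],
    model="BAAI/bge-reranker-base",  # placeholder; not a sentence-transformers model
    embed_url="https://example.invalid",  # unused when input_type is explicit
)

# Sync: returns the request body directly.
data = hf._transform_input(
    call_type="sync", optional_params={"input_type": "rerank"}, **kwargs
)
# -> {"inputs": {"query": "query", "texts": ["candidate"]}}

# Async: hands back a coroutine instead of a dict, so the caller awaits it.
data = asyncio.run(
    hf._transform_input(
        call_type="async", optional_params={"input_type": "rerank"}, **kwargs
    )
)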
def _process_embedding_response(
self,
embeddings: dict,
model_response: litellm.EmbeddingResponse,
model: str,
input: List,
encoding: Callable,
) -> litellm.EmbeddingResponse:
output_data = []
if "similarities" in embeddings:
for idx, embedding in embeddings["similarities"]:
@@ -888,3 +948,156 @@ class Huggingface(BaseLLM):
),
)
return model_response
async def aembedding(
self,
model: str,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
api_key: Optional[str],
headers: dict,
encoding: Callable,
client: Optional[AsyncHTTPHandler] = None,
):
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=api_base,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
},
)
## COMPLETION CALL
if client is None:
client = AsyncHTTPHandler(concurrent_limit=1)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
def embedding(
self,
model: str,
input: list,
model_response: litellm.EmbeddingResponse,
optional_params: dict,
logging_obj: LiteLLMLoggingObj,
encoding: Callable,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
aembedding: Optional[bool] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
) -> litellm.EmbeddingResponse:
super().embedding()
headers = self.validate_environment(api_key, headers=None)
# print_verbose(f"{model}, {task}")
embed_url = ""
if "https" in model:
embed_url = model
elif api_base:
embed_url = api_base
elif "HF_API_BASE" in os.environ:
embed_url = os.getenv("HF_API_BASE", "")
elif "HUGGINGFACE_API_BASE" in os.environ:
embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
else:
embed_url = f"https://api-inference.huggingface.co/models/{model}"
## ROUTING ##
if aembedding is True:
return self.aembedding(
input=input,
model_response=model_response,
timeout=timeout,
logging_obj=logging_obj,
headers=headers,
api_base=embed_url, # type: ignore
api_key=api_key,
client=client if isinstance(client, AsyncHTTPHandler) else None,
model=model,
optional_params=optional_params,
encoding=encoding,
)
## TRANSFORMATION ##
data = self._transform_input(
input=input,
model=model,
call_type="sync",
optional_params=optional_params,
embed_url=embed_url,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": embed_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
)
embeddings = response.json()
if "error" in embeddings:
raise HuggingfaceError(status_code=500, message=embeddings["error"])
## PROCESS RESPONSE ##
return self._process_embedding_response(
embeddings=embeddings,
model_response=model_response,
model=model,
input=input,
encoding=encoding,
)
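
With this routing, one entrypoint serves both modes, and callers can inject their own handler (the new test below does exactly this to mock the POST). Note the guards above: a sync call only reuses an HTTPHandler, the async branch only forwards an AsyncHTTPHandler, and anything else silently falls back to a fresh single-connection handler. A sketch:

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler

# Caller-owned client, reused across requests (placeholder model from the test).
client = HTTPHandler(concurrent_limit=1)
response = litellm.embedding(
    model="huggingface/TaylorAI/bge-micro-v2",
    input=["good morning from litellm", "this is another item"],
    client=client,
)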


@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "databricks"
or custom_llm_provider == "watsonx"
or custom_llm_provider == "huggingface"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
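
For context, the pattern this branch relies on: the synchronous embedding function runs in an executor, and because the huggingface handler now returns its aembedding coroutine when aembedding=True, the awaited executor result is itself awaitable. A generic sketch of that pattern, not litellm's literal code:

import asyncio
from functools import partial

async def _async_impl():
    return "async-result"

def sync_entrypoint(aembedding: bool = False):
    if aembedding:
        return _async_impl()  # returns a coroutine, not a result
    return "sync-result"

async def caller():
    loop = asyncio.get_running_loop()
    init_response = await loop.run_in_executor(
        None, partial(sync_entrypoint, aembedding=True)
    )
    if asyncio.iscoroutine(init_response):
        init_response = await init_response  # await the provider's coroutine
    return init_response

print(asyncio.run(caller()))  # -> "async-result"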
@@ -3450,7 +3451,7 @@ def embedding(
or litellm.huggingface_key
or get_secret("HUGGINGFACE_API_KEY")
or litellm.api_key
)
) # type: ignore
response = huggingface.embedding(
model=model,
input=input,
@@ -3459,6 +3460,9 @@ def embedding(
api_base=api_base,
logging_obj=logging,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "bedrock":
response = bedrock.embedding(


@@ -409,6 +409,62 @@ def test_hf_embedding():
# test_hf_embedding()
from unittest.mock import MagicMock, patch
def tgi_mock_post(*args, **kwargs):
import json
expected_data = {
"inputs": {
"source_sentence": "good morning from litellm",
"sentences": ["this is another item"],
}
}
assert (
json.loads(kwargs["data"]) == expected_data
), "Data does not match the expected data"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = [0.7708950042724609]
return mock_response
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_hf_embedding_sentence_sim(sync_mode):
try:
# huggingface/microsoft/codebert-base
# huggingface/facebook/bart-large
if sync_mode is True:
client = HTTPHandler(concurrent_limit=1)
else:
client = AsyncHTTPHandler(concurrent_limit=1)
with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
data = {
"model": "huggingface/TaylorAI/bge-micro-v2",
"input": ["good morning from litellm", "this is another item"],
"client": client,
}
if sync_mode is True:
response = embedding(**data)
else:
response = await litellm.aembedding(**data)
print(f"response:", response)
mock_client.assert_called_once()
assert isinstance(response.usage, litellm.Usage)
except Exception as e:
# Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
raise e
# test async embeddings
def test_aembedding():