feat(huggingface_restapi.py): Support multiple hf embedding types + async hf embeddings

Closes https://github.com/BerriAI/litellm/issues/3261
Krrish Dholakia 2024-07-30 13:32:03 -07:00
parent f1b7d2318c
commit 69afbc6091
3 changed files with 332 additions and 59 deletions
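
For orientation, the intended call pattern after this change looks roughly like the sketch below; the model name is taken from the new test at the bottom of this commit, and the `input_type` value (forwarded into `optional_params` and popped by the new `_transform_input` helper) is chosen for illustration:

    import litellm

    # input_type selects one of the supported embedding tasks
    # ("sentence-similarity", "feature-extraction", "rerank", "embed", "similarity");
    # when omitted, the model's pipeline tag is looked up instead.
    response = litellm.embedding(
        model="huggingface/TaylorAI/bge-micro-v2",
        input=["good morning from litellm", "this is another item"],
        input_type="sentence-similarity",
    )
    print(response.usage)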

litellm/llms/huggingface_restapi.py

@@ -6,12 +6,13 @@ import os
 import time
 import types
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, get_args
 import httpx
 import requests
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.types.completion import ChatCompletionMessageToolCallParam
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -60,6 +61,10 @@ hf_tasks = Literal[
     "text-generation",
 ]
+hf_tasks_embeddings = Literal[  # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
+    "sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
+]
 class HuggingfaceConfig:
     """
@@ -249,6 +254,55 @@ def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
     return "text-generation-inference", model  # default to tgi
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+
+def get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+    if task_type is not None:
+        if task_type in get_args(hf_tasks_embeddings):
+            return task_type
+        else:
+            raise Exception(
+                "Invalid task_type={}. Expected one of={}".format(
+                    task_type, hf_tasks_embeddings
+                )
+            )
+    http_client = HTTPHandler(concurrent_limit=1)
+    model_info = http_client.get(url=api_base)
+    model_info_dict = model_info.json()
+    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+    return pipeline_tag
+
+
+async def async_get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+    if task_type is not None:
+        if task_type in get_args(hf_tasks_embeddings):
+            return task_type
+        else:
+            raise Exception(
+                "Invalid task_type={}. Expected one of={}".format(
+                    task_type, hf_tasks_embeddings
+                )
+            )
+    http_client = AsyncHTTPHandler(concurrent_limit=1)
+    model_info = await http_client.get(url=api_base)
+    model_info_dict = model_info.json()
+    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+    return pipeline_tag
 class Huggingface(BaseLLM):
     _client_session: Optional[httpx.Client] = None
     _aclient_session: Optional[httpx.AsyncClient] = None
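
When no task_type is supplied, both helpers fall back to a GET against api_base and read the pipeline_tag field out of the JSON body. A standalone sketch of that lookup with plain httpx, assuming api_base resolves to the Hugging Face hub model-info endpoint (the URL construction is not part of this hunk):

    import httpx

    # Assumption: https://huggingface.co/api/models/<repo> returns JSON carrying
    # a "pipeline_tag" key such as "sentence-similarity" or "feature-extraction".
    resp = httpx.get("https://huggingface.co/api/models/TaylorAI/bge-micro-v2")
    pipeline_tag = resp.json().get("pipeline_tag", None)  # None when absent
    print(pipeline_tag)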
@@ -256,7 +310,7 @@ class Huggingface(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
-    def validate_environment(self, api_key, headers):
+    def validate_environment(self, api_key, headers) -> dict:
         default_headers = {
             "content-type": "application/json",
         }
@@ -762,76 +816,82 @@ class Huggingface(BaseLLM):
         async for transformed_chunk in streamwrapper:
             yield transformed_chunk
-    def embedding(
-        self,
-        model: str,
-        input: list,
-        model_response: litellm.EmbeddingResponse,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        logging_obj=None,
-        encoding=None,
-    ):
-        super().embedding()
-        headers = self.validate_environment(api_key, headers=None)
-        # print_verbose(f"{model}, {task}")
-        embed_url = ""
-        if "https" in model:
-            embed_url = model
-        elif api_base:
-            embed_url = api_base
-        elif "HF_API_BASE" in os.environ:
-            embed_url = os.getenv("HF_API_BASE", "")
-        elif "HUGGINGFACE_API_BASE" in os.environ:
-            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
-        else:
-            embed_url = f"https://api-inference.huggingface.co/models/{model}"
+    def _transform_input_on_pipeline_tag(
+        self, input: List, pipeline_tag: Optional[str]
+    ) -> dict:
+        if pipeline_tag is None:
+            return {"inputs": input}
+        if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
+            if len(input) < 2:
+                raise HuggingfaceError(
+                    status_code=400,
+                    message="sentence-similarity requires 2+ sentences",
+                )
+            return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
+        elif pipeline_tag == "rerank":
+            if len(input) < 2:
+                raise HuggingfaceError(
+                    status_code=400,
+                    message="reranker requires 2+ sentences",
+                )
+            return {"inputs": {"query": input[0], "texts": input[1:]}}
+        return {"inputs": input}  # default to feature-extraction pipeline tag
+
+    async def _async_transform_input(
+        self, model: str, task_type: Optional[str], embed_url: str, input: List
+    ) -> dict:
+        hf_task = await async_get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=embed_url
+        )
+        data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
+        return data
+
+    def _transform_input(
+        self,
+        input: List,
+        model: str,
+        call_type: Literal["sync", "async"],
+        optional_params: dict,
+        embed_url: str,
+    ) -> dict:
+        ## TRANSFORMATION ##
         if "sentence-transformers" in model:
             if len(input) == 0:
                 raise HuggingfaceError(
                     status_code=400,
                     message="sentence transformers requires 2+ sentences",
                 )
-            data = {
-                "inputs": {
-                    "source_sentence": input[0],
-                    "sentences": [
-                        "That is a happy dog",
-                        "That is a very happy person",
-                        "Today is a sunny day",
-                    ],
-                }
-            }
+            data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
         else:
             data = {"inputs": input}  # type: ignore
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input,
-            api_key=api_key,
-            additional_args={
-                "complete_input_dict": data,
-                "headers": headers,
-                "api_base": embed_url,
-            },
-        )
-        ## COMPLETION CALL
-        response = requests.post(embed_url, headers=headers, data=json.dumps(data))
-        ## LOGGING
-        logging_obj.post_call(
-            input=input,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data},
-            original_response=response,
-        )
-        embeddings = response.json()
-        if "error" in embeddings:
-            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+            task_type = optional_params.pop("input_type", None)
+            if call_type == "sync":
+                hf_task = get_hf_task_embedding_for_model(
+                    model=model, task_type=task_type, api_base=embed_url
+                )
+            elif call_type == "async":
+                return self._async_transform_input(
+                    model=model, task_type=task_type, embed_url=embed_url, input=input
+                )  # type: ignore
+            data = self._transform_input_on_pipeline_tag(
+                input=input, pipeline_tag=hf_task
+            )
+        return data
+
+    def _process_embedding_response(
+        self,
+        embeddings: dict,
+        model_response: litellm.EmbeddingResponse,
+        model: str,
+        input: List,
+        encoding: Callable,
+    ) -> litellm.EmbeddingResponse:
         output_data = []
         if "similarities" in embeddings:
             for idx, embedding in embeddings["similarities"]:
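
For reference, _transform_input_on_pipeline_tag maps one flat input list onto the request body each task expects; with a two-item input the resulting payloads are:

    input = ["good morning from litellm", "this is another item"]

    # "sentence-similarity" / "similarity":
    #     {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
    # "rerank":
    #     {"inputs": {"query": input[0], "texts": input[1:]}}
    # any other tag, or None (the feature-extraction default):
    #     {"inputs": input}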
@@ -888,3 +948,156 @@ class Huggingface(BaseLLM):
             ),
         )
         return model_response
+    async def aembedding(
+        self,
+        model: str,
+        input: list,
+        model_response: litellm.utils.EmbeddingResponse,
+        timeout: Union[float, httpx.Timeout],
+        logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
+        api_base: str,
+        api_key: Optional[str],
+        headers: dict,
+        encoding: Callable,
+        client: Optional[AsyncHTTPHandler] = None,
+    ):
+        ## TRANSFORMATION ##
+        data = self._transform_input(
+            input=input,
+            model=model,
+            call_type="sync",
+            optional_params=optional_params,
+            embed_url=api_base,
+        )
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": api_base,
+            },
+        )
+        ## COMPLETION CALL
+        if client is None:
+            client = AsyncHTTPHandler(concurrent_limit=1)
+        response = await client.post(api_base, headers=headers, data=json.dumps(data))
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response,
+        )
+        embeddings = response.json()
+        if "error" in embeddings:
+            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+        ## PROCESS RESPONSE ##
+        return self._process_embedding_response(
+            embeddings=embeddings,
+            model_response=model_response,
+            model=model,
+            input=input,
+            encoding=encoding,
+        )
+
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
+        encoding: Callable,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        aembedding: Optional[bool] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> litellm.EmbeddingResponse:
+        super().embedding()
+        headers = self.validate_environment(api_key, headers=None)
+        # print_verbose(f"{model}, {task}")
+        embed_url = ""
+        if "https" in model:
+            embed_url = model
+        elif api_base:
+            embed_url = api_base
+        elif "HF_API_BASE" in os.environ:
+            embed_url = os.getenv("HF_API_BASE", "")
+        elif "HUGGINGFACE_API_BASE" in os.environ:
+            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
+        else:
+            embed_url = f"https://api-inference.huggingface.co/models/{model}"
+        ## ROUTING ##
+        if aembedding is True:
+            return self.aembedding(
+                input=input,
+                model_response=model_response,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                headers=headers,
+                api_base=embed_url,  # type: ignore
+                api_key=api_key,
+                client=client if isinstance(client, AsyncHTTPHandler) else None,
+                model=model,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
+        ## TRANSFORMATION ##
+        data = self._transform_input(
+            input=input,
+            model=model,
+            call_type="sync",
+            optional_params=optional_params,
+            embed_url=embed_url,
+        )
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": embed_url,
+            },
+        )
+        ## COMPLETION CALL
+        if client is None or not isinstance(client, HTTPHandler):
+            client = HTTPHandler(concurrent_limit=1)
+        response = client.post(embed_url, headers=headers, data=json.dumps(data))
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response,
+        )
+        embeddings = response.json()
+        if "error" in embeddings:
+            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+        ## PROCESS RESPONSE ##
+        return self._process_embedding_response(
+            embeddings=embeddings,
+            model_response=model_response,
+            model=model,
+            input=input,
+            encoding=encoding,
+        )
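
Because both new signatures accept an optional client, callers can inject a preconfigured handler instead of having a fresh one created per request; the mocked test at the end of this commit relies on exactly that. A hedged usage sketch:

    import litellm
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    client = HTTPHandler(concurrent_limit=1)  # reused across embedding calls
    response = litellm.embedding(
        model="huggingface/TaylorAI/bge-micro-v2",
        input=["good morning from litellm", "this is another item"],
        client=client,
    )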

litellm/main.py

@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
             or custom_llm_provider == "vertex_ai"
             or custom_llm_provider == "databricks"
             or custom_llm_provider == "watsonx"
+            or custom_llm_provider == "huggingface"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
             # Await normally
             init_response = await loop.run_in_executor(None, func_with_context)
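
With "huggingface" added to this provider list, litellm.aembedding awaits the coroutine returned by huggingface.embedding(..., aembedding=True) directly instead of pushing a blocking call into the executor. A minimal async sketch (assumes HUGGINGFACE_API_KEY is set in the environment):

    import asyncio
    import litellm

    async def main():
        response = await litellm.aembedding(
            model="huggingface/TaylorAI/bge-micro-v2",
            input=["good morning from litellm", "this is another item"],
        )
        print(response.usage)

    asyncio.run(main())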
@@ -3450,7 +3451,7 @@ def embedding(
             or litellm.huggingface_key
             or get_secret("HUGGINGFACE_API_KEY")
             or litellm.api_key
-        )
+        )  # type: ignore
         response = huggingface.embedding(
             model=model,
             input=input,
@@ -3459,6 +3460,9 @@ def embedding(
             api_base=api_base,
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
         )
     elif custom_llm_provider == "bedrock":
         response = bedrock.embedding(

litellm/tests/test_embedding.py

@@ -409,6 +409,62 @@ def test_hf_embedding():
 # test_hf_embedding()
+from unittest.mock import MagicMock, patch
+
+
+def tgi_mock_post(*args, **kwargs):
+    import json
+
+    expected_data = {
+        "inputs": {
+            "source_sentence": "good morning from litellm",
+            "sentences": ["this is another item"],
+        }
+    }
+    assert (
+        json.loads(kwargs["data"]) == expected_data
+    ), "Data does not match the expected data"
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [0.7708950042724609]
+    return mock_response
+
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_hf_embedding_sentence_sim(sync_mode):
+    try:
+        # huggingface/microsoft/codebert-base
+        # huggingface/facebook/bart-large
+        if sync_mode is True:
+            client = HTTPHandler(concurrent_limit=1)
+        else:
+            client = AsyncHTTPHandler(concurrent_limit=1)
+        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
+            data = {
+                "model": "huggingface/TaylorAI/bge-micro-v2",
+                "input": ["good morning from litellm", "this is another item"],
+                "client": client,
+            }
+            if sync_mode is True:
+                response = embedding(**data)
+            else:
+                response = await litellm.aembedding(**data)
+            print(f"response:", response)
+            mock_client.assert_called_once()
+        assert isinstance(response.usage, litellm.Usage)
+    except Exception as e:
+        # Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
+        raise e
 # test async embeddings
 def test_aembedding():
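
The same mock pattern extends to the other payload shapes; a hedged sketch of a rerank-style mock, where the function name is hypothetical, the expected body mirrors _transform_input_on_pipeline_tag above, and the response shape is an assumed TEI rerank result:

    from unittest.mock import MagicMock

    def rerank_mock_post(*args, **kwargs):
        import json

        # First input item becomes "query", the rest become "texts".
        expected_data = {
            "inputs": {
                "query": "good morning from litellm",
                "texts": ["this is another item"],
            }
        }
        assert json.loads(kwargs["data"]) == expected_data
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json.return_value = [{"index": 0, "score": 0.77}]
        return mock_response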