diff --git a/docs/my-website/docs/embedding/supported_embedding.md b/docs/my-website/docs/embedding/supported_embedding.md
index 73ac47755..aa3c2c4c5 100644
--- a/docs/my-website/docs/embedding/supported_embedding.md
+++ b/docs/my-website/docs/embedding/supported_embedding.md
@@ -270,7 +270,7 @@ response = embedding(
 | embed-multilingual-v2.0 | `embedding(model="embed-multilingual-v2.0", input=["good morning from litellm", "this is another item"])` |
 
 ## HuggingFace Embedding Models
-LiteLLM supports all Feature-Extraction Embedding models: https://huggingface.co/models?pipeline_tag=feature-extraction
+LiteLLM supports all Feature-Extraction + Sentence Similarity Embedding models: https://huggingface.co/models?pipeline_tag=feature-extraction
 
 ### Usage
 ```python
@@ -282,6 +282,25 @@ response = embedding(
     input=["good morning from litellm"]
 )
 ```
+
+### Usage - Set input_type
+
+LiteLLM infers the input type (`feature-extraction` or `sentence-similarity`) by making a GET request to the API base and reading the model's `pipeline_tag`.
+
+Override this by setting `input_type` yourself:
+
+```python
+from litellm import embedding
+import os
+os.environ['HUGGINGFACE_API_KEY'] = ""
+response = embedding(
+    model='huggingface/microsoft/codebert-base',
+    input=["good morning from litellm", "you are a good bot"],
+    api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
+    input_type="sentence-similarity"
+)
+```
+
 ### Usage - Custom API Base
 ```python
 from litellm import embedding
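The inference step described in the docs above is a single GET against the endpoint's metadata. A minimal standalone sketch of that lookup (assuming, as the new `get_hf_task_embedding_for_model` helper in this patch does, that the endpoint returns JSON with a `pipeline_tag` field; the URL below is a placeholder, not a real endpoint):

```python
from typing import Optional

import requests


def infer_hf_task(api_base: str) -> Optional[str]:
    """GET the endpoint's model metadata and return its `pipeline_tag`, if any."""
    model_info = requests.get(api_base).json()
    return model_info.get("pipeline_tag")  # e.g. "sentence-similarity"


# infer_hf_task("https://<your-endpoint>.us-east-1.aws.endpoints.huggingface.cloud")
```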
diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
index f3c2770b3..3873027b2 100644
--- a/litellm/llms/cohere.py
+++ b/litellm/llms/cohere.py
@@ -1,3 +1,6 @@
+#################### OLD ########################
+##### See `cohere_chat.py` for `/chat` calls ####
+#################################################
 import json
 import os
 import time
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 8b755e2bb..2910a644e 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -6,12 +6,13 @@
 import os
 import time
 import types
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, get_args
 
 import httpx
 import requests
 
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.types.completion import ChatCompletionMessageToolCallParam
 from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -60,6 +61,10 @@ hf_tasks = Literal[
     "text-generation",
 ]
 
+hf_tasks_embeddings = Literal[  # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
+    "sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
+]
+
 
 class HuggingfaceConfig:
     """
@@ -249,6 +254,55 @@ def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
         return "text-generation-inference", model  # default to tgi
 
 
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+
+def get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+    if task_type is not None:
+        if task_type in get_args(hf_tasks_embeddings):
+            return task_type
+        else:
+            raise Exception(
+                "Invalid task_type={}. Expected one of={}".format(
+                    task_type, hf_tasks_embeddings
+                )
+            )
+    http_client = HTTPHandler(concurrent_limit=1)
+
+    model_info = http_client.get(url=api_base)
+
+    model_info_dict = model_info.json()
+
+    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+
+    return pipeline_tag
+
+
+async def async_get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+    if task_type is not None:
+        if task_type in get_args(hf_tasks_embeddings):
+            return task_type
+        else:
+            raise Exception(
+                "Invalid task_type={}. Expected one of={}".format(
+                    task_type, hf_tasks_embeddings
+                )
+            )
+    http_client = AsyncHTTPHandler(concurrent_limit=1)
+
+    model_info = await http_client.get(url=api_base)
+
+    model_info_dict = model_info.json()
+
+    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+
+    return pipeline_tag
+
+
 class Huggingface(BaseLLM):
     _client_session: Optional[httpx.Client] = None
     _aclient_session: Optional[httpx.AsyncClient] = None
@@ -256,7 +310,7 @@ class Huggingface(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
 
-    def validate_environment(self, api_key, headers):
+    def _validate_environment(self, api_key, headers) -> dict:
         default_headers = {
             "content-type": "application/json",
         }
@@ -406,7 +460,7 @@ class Huggingface(BaseLLM):
         super().completion()
         exception_mapping_worked = False
         try:
-            headers = self.validate_environment(api_key, headers)
+            headers = self._validate_environment(api_key, headers)
             task, model = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
@@ -762,76 +816,82 @@ class Huggingface(BaseLLM):
             async for transformed_chunk in streamwrapper:
                 yield transformed_chunk
 
-    def embedding(
-        self,
-        model: str,
-        input: list,
-        model_response: litellm.EmbeddingResponse,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        logging_obj=None,
-        encoding=None,
-    ):
-        super().embedding()
-        headers = self.validate_environment(api_key, headers=None)
-        # print_verbose(f"{model}, {task}")
-        embed_url = ""
-        if "https" in model:
-            embed_url = model
-        elif api_base:
-            embed_url = api_base
-        elif "HF_API_BASE" in os.environ:
-            embed_url = os.getenv("HF_API_BASE", "")
-        elif "HUGGINGFACE_API_BASE" in os.environ:
-            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
-        else:
-            embed_url = f"https://api-inference.huggingface.co/models/{model}"
+    def _transform_input_on_pipeline_tag(
+        self, input: List, pipeline_tag: Optional[str]
+    ) -> dict:
+        if pipeline_tag is None:
+            return {"inputs": input}
+        if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
+            if len(input) < 2:
+                raise HuggingfaceError(
+                    status_code=400,
+                    message="sentence-similarity requires 2+ sentences",
+                )
+            return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
+        elif pipeline_tag == "rerank":
+            if len(input) < 2:
+                raise HuggingfaceError(
+                    status_code=400,
+                    message="reranker requires 2+ sentences",
+                )
+            return {"inputs": {"query": input[0], "texts": input[1:]}}
+        return {"inputs": input}  # default to feature-extraction pipeline tag
 
+    async def _async_transform_input(
+        self, model: str, task_type: Optional[str], embed_url: str, input: List
+    ) -> dict:
+        hf_task = await async_get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=embed_url
+        )
+
+        data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
+
+        return data
+
+    def _transform_input(
+        self,
+        input: List,
+        model: str,
+        call_type: Literal["sync", "async"],
+        optional_params: dict,
+        embed_url: str,
+    ) -> dict:
+        ## TRANSFORMATION ##
         if "sentence-transformers" in model:
             if len(input) == 0:
                 raise HuggingfaceError(
                     status_code=400,
                     message="sentence transformers requires 2+ sentences",
                 )
-            data = {
-                "inputs": {
-                    "source_sentence": input[0],
-                    "sentences": [
-                        "That is a happy dog",
-                        "That is a very happy person",
-                        "Today is a sunny day",
-                    ],
-                }
-            }
+            data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
         else:
             data = {"inputs": input}  # type: ignore
 
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input,
-            api_key=api_key,
-            additional_args={
-                "complete_input_dict": data,
-                "headers": headers,
-                "api_base": embed_url,
-            },
-        )
-        ## COMPLETION CALL
-        response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+        task_type = optional_params.pop("input_type", None)
 
-        ## LOGGING
-        logging_obj.post_call(
-            input=input,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data},
-            original_response=response,
-        )
+        if call_type == "sync":
+            hf_task = get_hf_task_embedding_for_model(
+                model=model, task_type=task_type, api_base=embed_url
+            )
+        elif call_type == "async":
+            return self._async_transform_input(
+                model=model, task_type=task_type, embed_url=embed_url, input=input
+            )  # type: ignore
 
-        embeddings = response.json()
+        data = self._transform_input_on_pipeline_tag(
+            input=input, pipeline_tag=hf_task
+        )
 
-        if "error" in embeddings:
-            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+        return data
 
+    def _process_embedding_response(
+        self,
+        embeddings: dict,
+        model_response: litellm.EmbeddingResponse,
+        model: str,
+        input: List,
+        encoding: Any,
+    ) -> litellm.EmbeddingResponse:
         output_data = []
         if "similarities" in embeddings:
             for idx, embedding in embeddings["similarities"]:
@@ -888,3 +948,156 @@ class Huggingface(BaseLLM):
             ),
         )
         return model_response
+
+    async def aembedding(
+        self,
+        model: str,
+        input: list,
+        model_response: litellm.utils.EmbeddingResponse,
+        timeout: Union[float, httpx.Timeout],
+        logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
+        api_base: str,
+        api_key: Optional[str],
+        headers: dict,
+        encoding: Callable,
+        client: Optional[AsyncHTTPHandler] = None,
+    ):
+        ## TRANSFORMATION ##
+        data = self._transform_input(
+            input=input,
+            model=model,
+            call_type="sync",
+            optional_params=optional_params,
+            embed_url=api_base,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": api_base,
+            },
+        )
+        ## COMPLETION CALL
+        if client is None:
+            client = AsyncHTTPHandler(concurrent_limit=1)
+
+        response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response,
+        )
+
+        embeddings = response.json()
+
+        if "error" in embeddings:
+            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+
+        ## PROCESS RESPONSE ##
+        return self._process_embedding_response(
+            embeddings=embeddings,
+            model_response=model_response,
+            model=model,
+            input=input,
+            encoding=encoding,
+        )
+
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
+        encoding: Callable,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        aembedding: Optional[bool] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> litellm.EmbeddingResponse:
+        super().embedding()
+        headers = self._validate_environment(api_key, headers=None)
+        # print_verbose(f"{model}, {task}")
+        embed_url = ""
+        if "https" in model:
+            embed_url = model
+        elif api_base:
+            embed_url = api_base
+        elif "HF_API_BASE" in os.environ:
+            embed_url = os.getenv("HF_API_BASE", "")
+        elif "HUGGINGFACE_API_BASE" in os.environ:
+            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
+        else:
+            embed_url = f"https://api-inference.huggingface.co/models/{model}"
+
+        ## ROUTING ##
+        if aembedding is True:
+            return self.aembedding(
+                input=input,
+                model_response=model_response,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                headers=headers,
+                api_base=embed_url,  # type: ignore
+                api_key=api_key,
+                client=client if isinstance(client, AsyncHTTPHandler) else None,
+                model=model,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
+
+        ## TRANSFORMATION ##
+
+        data = self._transform_input(
+            input=input,
+            model=model,
+            call_type="sync",
+            optional_params=optional_params,
+            embed_url=embed_url,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": embed_url,
+            },
+        )
+        ## COMPLETION CALL
+        if client is None or not isinstance(client, HTTPHandler):
+            client = HTTPHandler(concurrent_limit=1)
+        response = client.post(embed_url, headers=headers, data=json.dumps(data))
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response,
+        )
+
+        embeddings = response.json()
+
+        if "error" in embeddings:
+            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+
+        ## PROCESS RESPONSE ##
+        return self._process_embedding_response(
+            embeddings=embeddings,
+            model_response=model_response,
+            model=model,
+            input=input,
+            encoding=encoding,
+        )
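The payload shapes built by `_transform_input_on_pipeline_tag` above are easiest to see in isolation. A runnable sketch of the same mapping (shapes copied from the patch; inputs are examples):

```python
from typing import List, Optional


def build_payload(input: List[str], pipeline_tag: Optional[str]) -> dict:
    """Mirror of the payload shapes in `_transform_input_on_pipeline_tag`."""
    if pipeline_tag in ("sentence-similarity", "similarity"):
        # first item is the anchor; the rest are compared against it
        return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
    if pipeline_tag == "rerank":
        # first item is the query; the rest are candidate documents
        return {"inputs": {"query": input[0], "texts": input[1:]}}
    return {"inputs": input}  # feature-extraction / unknown tag


print(build_payload(["good morning from litellm", "this is another item"], "sentence-similarity"))
# {'inputs': {'source_sentence': 'good morning from litellm', 'sentences': ['this is another item']}}
```

The first list element is treated as the anchor (source sentence or query) and the remainder as the comparison set, which is why the sentence-similarity and rerank paths require at least two inputs.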
diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py
index 7b15582f4..6b984e1d8 100644
--- a/litellm/llms/ollama.py
+++ b/litellm/llms/ollama.py
@@ -258,7 +258,7 @@ def get_ollama_response(
             logging_obj=logging_obj,
         )
         return response
-    elif stream == True:
+    elif stream is True:
         return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)
 
     response = requests.post(
@@ -326,7 +326,7 @@ def ollama_completion_stream(url, data, logging_obj):
         try:
             if response.status_code != 200:
                 raise OllamaError(
-                    status_code=response.status_code, message=response.text
+                    status_code=response.status_code, message=response.read()
                 )
 
             streamwrapper = litellm.CustomStreamWrapper(
diff --git a/litellm/main.py b/litellm/main.py
index 528cbf071..429efb6c0 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3115,6 +3115,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
             or custom_llm_provider == "databricks"
             or custom_llm_provider == "watsonx"
             or custom_llm_provider == "cohere"
+            or custom_llm_provider == "huggingface"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
            # Await normally
            init_response = await loop.run_in_executor(None, func_with_context)
@@ -3454,15 +3455,18 @@ def embedding(
             or litellm.huggingface_key
             or get_secret("HUGGINGFACE_API_KEY")
             or litellm.api_key
-        )
+        )  # type: ignore
         response = huggingface.embedding(
             model=model,
             input=input,
-            encoding=encoding,
+            encoding=encoding,  # type: ignore
             api_key=api_key,
             api_base=api_base,
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
         )
     elif custom_llm_provider == "bedrock":
         response = bedrock.embedding(
diff --git a/litellm/router.py b/litellm/router.py
index d72f3ea5e..fcbd3a230 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -2944,7 +2944,7 @@ class Router:
         elif isinstance(id, int):
             id = str(id)
 
-        total_tokens = completion_response["usage"]["total_tokens"]
+        total_tokens = completion_response["usage"].get("total_tokens", 0)
 
         # ------------
         # Setup values
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index c44967a9a..62487b488 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -415,6 +415,62 @@ def test_hf_embedding():
 
 
 # test_hf_embedding()
+from unittest.mock import MagicMock, patch
+
+
+def tgi_mock_post(*args, **kwargs):
+    import json
+
+    expected_data = {
+        "inputs": {
+            "source_sentence": "good morning from litellm",
+            "sentences": ["this is another item"],
+        }
+    }
+    assert (
+        json.loads(kwargs["data"]) == expected_data
+    ), "Data does not match the expected data"
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [0.7708950042724609]
+    return mock_response
+
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_hf_embedding_sentence_sim(sync_mode):
+    try:
+        # huggingface/microsoft/codebert-base
+        # huggingface/facebook/bart-large
+        if sync_mode is True:
+            client = HTTPHandler(concurrent_limit=1)
+        else:
+            client = AsyncHTTPHandler(concurrent_limit=1)
+        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
+            data = {
+                "model": "huggingface/TaylorAI/bge-micro-v2",
+                "input": ["good morning from litellm", "this is another item"],
+                "client": client,
+            }
+            if sync_mode is True:
+                response = embedding(**data)
+            else:
+                response = await litellm.aembedding(**data)
+
+            print("response:", response)
+
+            mock_client.assert_called_once()
+
+            assert isinstance(response.usage, litellm.Usage)
+
+    except Exception as e:
+        # Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
+        raise e
+
 
 # test async embeddings
 def test_aembedding():
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 9aebc0f24..cebac19ea 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -571,6 +571,8 @@ async def test_completion_predibase_streaming(sync_mode):
         pass
     except litellm.InternalServerError as e:
         pass
+    except litellm.ServiceUnavailableError as e:
+        pass
     except Exception as e:
         print("ERROR class", e.__class__)
         print("ERROR message", e)
@@ -1764,6 +1766,7 @@ async def test_sagemaker_streaming_async():
 # asyncio.run(test_sagemaker_streaming_async())
 
 
+@pytest.mark.skip(reason="costly sagemaker deployment. Move to mock implementation")
 def test_completion_sagemaker_stream():
     try:
         response = completion(
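With the `main.py` routing change above, the async path can be exercised the same way the new test does. A minimal usage sketch (the model and inputs mirror the test; a real `HUGGINGFACE_API_KEY` is required, and `input_type` is optional):

```python
import asyncio
import os

import litellm

os.environ["HUGGINGFACE_API_KEY"] = ""  # set a real key


async def main():
    # Without `input_type`, LiteLLM GETs the model metadata and infers the
    # task from its `pipeline_tag`; pass input_type="sentence-similarity"
    # (or another supported tag) to skip that lookup.
    response = await litellm.aembedding(
        model="huggingface/TaylorAI/bge-micro-v2",  # example model from the test above
        input=["good morning from litellm", "this is another item"],
    )
    print(response.usage)


asyncio.run(main())
```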