Merge branch 'main' into litellm_async_cohere_calls

commit 653aefde40
Krish Dholakia, 2024-07-30 15:35:20 -07:00, committed by GitHub
8 changed files with 363 additions and 65 deletions


@@ -270,7 +270,7 @@ response = embedding(
| embed-multilingual-v2.0 | `embedding(model="embed-multilingual-v2.0", input=["good morning from litellm", "this is another item"])` |

## HuggingFace Embedding Models
-LiteLLM supports all Feature-Extraction Embedding models: https://huggingface.co/models?pipeline_tag=feature-extraction
+LiteLLM supports all Feature-Extraction + Sentence Similarity Embedding models: https://huggingface.co/models?pipeline_tag=feature-extraction

### Usage
```python
@@ -282,6 +282,25 @@ response = embedding(
    input=["good morning from litellm"]
)
```
+
+### Usage - Set input_type
+
+LiteLLM infers input type (feature-extraction or sentence-similarity) by making a GET request to the api base.
+Override this, by setting the `input_type` yourself.
+
+```python
+from litellm import embedding
+import os
+
+os.environ['HUGGINGFACE_API_KEY'] = ""
+
+response = embedding(
+    model='huggingface/microsoft/codebert-base',
+    input=["good morning from litellm", "you are a good bot"],
+    api_base = "https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
+    input_type="sentence-similarity"
+)
+```

### Usage - Custom API Base
```python
from litellm import embedding
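The added docs snippet covers the synchronous call only. Since this merge also wires HuggingFace embeddings into the async path (see the `aembedding` hunks below), an equivalent async sketch, assuming `litellm.aembedding` forwards the same `input_type` and `api_base` keyword arguments as the synchronous `embedding` call shown above, might look like:

```python
import asyncio
import os

import litellm

os.environ["HUGGINGFACE_API_KEY"] = ""


async def main():
    # Hedged sketch: assumes aembedding accepts input_type/api_base exactly like
    # the synchronous embedding() call documented in the hunk above.
    response = await litellm.aembedding(
        model="huggingface/microsoft/codebert-base",
        input=["good morning from litellm", "you are a good bot"],
        api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
        input_type="sentence-similarity",
    )
    print(response)


asyncio.run(main())
```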


@@ -1,3 +1,6 @@
+#################### OLD ########################
+##### See `cohere_chat.py` for `/chat` calls ####
+#################################################
import json
import os
import time


@@ -6,12 +6,13 @@ import os
import time
import types
from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union, get_args

import httpx
import requests

import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@@ -60,6 +61,10 @@ hf_tasks = Literal[
    "text-generation",
]

+hf_tasks_embeddings = Literal[  # pipeline tags + hf tei endpoints - https://huggingface.github.io/text-embeddings-inference/#/
+    "sentence-similarity", "feature-extraction", "rerank", "embed", "similarity"
+]
+

class HuggingfaceConfig:
    """
@@ -249,6 +254,55 @@ def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
        return "text-generation-inference", model  # default to tgi

+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+
+def get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+    if task_type is not None:
+        if task_type in get_args(hf_tasks_embeddings):
+            return task_type
+        else:
+            raise Exception(
+                "Invalid task_type={}. Expected one of={}".format(
+                    task_type, hf_tasks_embeddings
+                )
+            )
+    http_client = HTTPHandler(concurrent_limit=1)
+
+    model_info = http_client.get(url=api_base)
+
+    model_info_dict = model_info.json()
+
+    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+
+    return pipeline_tag
+
+
+async def async_get_hf_task_embedding_for_model(
+    model: str, task_type: Optional[str], api_base: str
+) -> Optional[str]:
+    if task_type is not None:
+        if task_type in get_args(hf_tasks_embeddings):
+            return task_type
+        else:
+            raise Exception(
+                "Invalid task_type={}. Expected one of={}".format(
+                    task_type, hf_tasks_embeddings
+                )
+            )
+    http_client = AsyncHTTPHandler(concurrent_limit=1)
+
+    model_info = await http_client.get(url=api_base)
+
+    model_info_dict = model_info.json()
+
+    pipeline_tag: Optional[str] = model_info_dict.get("pipeline_tag", None)
+
+    return pipeline_tag
+
+
class Huggingface(BaseLLM):
    _client_session: Optional[httpx.Client] = None
    _aclient_session: Optional[httpx.AsyncClient] = None
@@ -256,7 +310,7 @@ class Huggingface(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

-    def validate_environment(self, api_key, headers):
+    def _validate_environment(self, api_key, headers) -> dict:
        default_headers = {
            "content-type": "application/json",
        }
@@ -406,7 +460,7 @@ class Huggingface(BaseLLM):
        super().completion()
        exception_mapping_worked = False
        try:
-            headers = self.validate_environment(api_key, headers)
+            headers = self._validate_environment(api_key, headers)
            task, model = get_hf_task_for_model(model)
            ## VALIDATE API FORMAT
            if task is None or not isinstance(task, str) or task not in hf_task_list:
@@ -762,76 +816,82 @@ class Huggingface(BaseLLM):
        async for transformed_chunk in streamwrapper:
            yield transformed_chunk

-    def embedding(
-        self,
-        model: str,
-        input: list,
-        model_response: litellm.EmbeddingResponse,
-        api_key: Optional[str] = None,
-        api_base: Optional[str] = None,
-        logging_obj=None,
-        encoding=None,
-    ):
-        super().embedding()
-        headers = self.validate_environment(api_key, headers=None)
-        # print_verbose(f"{model}, {task}")
-        embed_url = ""
-        if "https" in model:
-            embed_url = model
-        elif api_base:
-            embed_url = api_base
-        elif "HF_API_BASE" in os.environ:
-            embed_url = os.getenv("HF_API_BASE", "")
-        elif "HUGGINGFACE_API_BASE" in os.environ:
-            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
-        else:
-            embed_url = f"https://api-inference.huggingface.co/models/{model}"
+    def _transform_input_on_pipeline_tag(
+        self, input: List, pipeline_tag: Optional[str]
+    ) -> dict:
+        if pipeline_tag is None:
+            return {"inputs": input}
+        if pipeline_tag == "sentence-similarity" or pipeline_tag == "similarity":
+            if len(input) < 2:
+                raise HuggingfaceError(
+                    status_code=400,
+                    message="sentence-similarity requires 2+ sentences",
+                )
+            return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
+        elif pipeline_tag == "rerank":
+            if len(input) < 2:
+                raise HuggingfaceError(
+                    status_code=400,
+                    message="reranker requires 2+ sentences",
+                )
+            return {"inputs": {"query": input[0], "texts": input[1:]}}
+        return {"inputs": input}  # default to feature-extraction pipeline tag
+
+    async def _async_transform_input(
+        self, model: str, task_type: Optional[str], embed_url: str, input: List
+    ) -> dict:
+        hf_task = await async_get_hf_task_embedding_for_model(
+            model=model, task_type=task_type, api_base=embed_url
+        )
+
+        data = self._transform_input_on_pipeline_tag(input=input, pipeline_tag=hf_task)
+
+        return data
+
+    def _transform_input(
+        self,
+        input: List,
+        model: str,
+        call_type: Literal["sync", "async"],
+        optional_params: dict,
+        embed_url: str,
+    ) -> dict:
+        ## TRANSFORMATION ##
        if "sentence-transformers" in model:
            if len(input) == 0:
                raise HuggingfaceError(
                    status_code=400,
                    message="sentence transformers requires 2+ sentences",
                )
-            data = {
-                "inputs": {
-                    "source_sentence": input[0],
-                    "sentences": [
-                        "That is a happy dog",
-                        "That is a very happy person",
-                        "Today is a sunny day",
-                    ],
-                }
-            }
+            data = {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
        else:
            data = {"inputs": input}  # type: ignore
-        ## LOGGING
-        logging_obj.pre_call(
-            input=input,
-            api_key=api_key,
-            additional_args={
-                "complete_input_dict": data,
-                "headers": headers,
-                "api_base": embed_url,
-            },
-        )
-        ## COMPLETION CALL
-        response = requests.post(embed_url, headers=headers, data=json.dumps(data))
-        ## LOGGING
-        logging_obj.post_call(
-            input=input,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data},
-            original_response=response,
-        )
-        embeddings = response.json()
-        if "error" in embeddings:
-            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+
+        task_type = optional_params.pop("input_type", None)
+
+        if call_type == "sync":
+            hf_task = get_hf_task_embedding_for_model(
+                model=model, task_type=task_type, api_base=embed_url
+            )
+        elif call_type == "async":
+            return self._async_transform_input(
+                model=model, task_type=task_type, embed_url=embed_url, input=input
+            )  # type: ignore
+
+        data = self._transform_input_on_pipeline_tag(
+            input=input, pipeline_tag=hf_task
+        )
+
+        return data
+
+    def _process_embedding_response(
+        self,
+        embeddings: dict,
+        model_response: litellm.EmbeddingResponse,
+        model: str,
+        input: List,
+        encoding: Any,
+    ) -> litellm.EmbeddingResponse:
        output_data = []
        if "similarities" in embeddings:
            for idx, embedding in embeddings["similarities"]:
@@ -888,3 +948,156 @@ class Huggingface(BaseLLM):
                ),
            )
        return model_response
+
+    async def aembedding(
+        self,
+        model: str,
+        input: list,
+        model_response: litellm.utils.EmbeddingResponse,
+        timeout: Union[float, httpx.Timeout],
+        logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
+        api_base: str,
+        api_key: Optional[str],
+        headers: dict,
+        encoding: Callable,
+        client: Optional[AsyncHTTPHandler] = None,
+    ):
+        ## TRANSFORMATION ##
+        data = self._transform_input(
+            input=input,
+            model=model,
+            call_type="sync",
+            optional_params=optional_params,
+            embed_url=api_base,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": api_base,
+            },
+        )
+        ## COMPLETION CALL
+        if client is None:
+            client = AsyncHTTPHandler(concurrent_limit=1)
+
+        response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response,
+        )
+
+        embeddings = response.json()
+
+        if "error" in embeddings:
+            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+
+        ## PROCESS RESPONSE ##
+        return self._process_embedding_response(
+            embeddings=embeddings,
+            model_response=model_response,
+            model=model,
+            input=input,
+            encoding=encoding,
+        )
+
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
+        encoding: Callable,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+        aembedding: Optional[bool] = None,
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> litellm.EmbeddingResponse:
+        super().embedding()
+        headers = self._validate_environment(api_key, headers=None)
+        # print_verbose(f"{model}, {task}")
+        embed_url = ""
+        if "https" in model:
+            embed_url = model
+        elif api_base:
+            embed_url = api_base
+        elif "HF_API_BASE" in os.environ:
+            embed_url = os.getenv("HF_API_BASE", "")
+        elif "HUGGINGFACE_API_BASE" in os.environ:
+            embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
+        else:
+            embed_url = f"https://api-inference.huggingface.co/models/{model}"
+
+        ## ROUTING ##
+        if aembedding is True:
+            return self.aembedding(
+                input=input,
+                model_response=model_response,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                headers=headers,
+                api_base=embed_url,  # type: ignore
+                api_key=api_key,
+                client=client if isinstance(client, AsyncHTTPHandler) else None,
+                model=model,
+                optional_params=optional_params,
+                encoding=encoding,
+            )
+
+        ## TRANSFORMATION ##
+        data = self._transform_input(
+            input=input,
+            model=model,
+            call_type="sync",
+            optional_params=optional_params,
+            embed_url=embed_url,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=input,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "headers": headers,
+                "api_base": embed_url,
+            },
+        )
+        ## COMPLETION CALL
+        if client is None or not isinstance(client, HTTPHandler):
+            client = HTTPHandler(concurrent_limit=1)
+        response = client.post(embed_url, headers=headers, data=json.dumps(data))
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=input,
+            api_key=api_key,
+            additional_args={"complete_input_dict": data},
+            original_response=response,
+        )
+
+        embeddings = response.json()
+
+        if "error" in embeddings:
+            raise HuggingfaceError(status_code=500, message=embeddings["error"])
+
+        ## PROCESS RESPONSE ##
+        return self._process_embedding_response(
+            embeddings=embeddings,
+            model_response=model_response,
+            model=model,
+            input=input,
+            encoding=encoding,
+        )
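To make the new request-shaping logic easier to scan, here is a standalone restatement (plain Python, with the library's `HuggingfaceError` swapped for `ValueError`) of the payload shapes that `_transform_input_on_pipeline_tag` produces for each pipeline tag:

```python
from typing import List, Optional


def build_hf_embedding_payload(input: List[str], pipeline_tag: Optional[str]) -> dict:
    # Restates the mapping added in the diff above; ValueError stands in for HuggingfaceError.
    if pipeline_tag in ("sentence-similarity", "similarity"):
        if len(input) < 2:
            raise ValueError("sentence-similarity requires 2+ sentences")
        # the first sentence is compared against all remaining sentences
        return {"inputs": {"source_sentence": input[0], "sentences": input[1:]}}
    if pipeline_tag == "rerank":
        if len(input) < 2:
            raise ValueError("reranker requires 2+ sentences")
        return {"inputs": {"query": input[0], "texts": input[1:]}}
    # feature-extraction, or an unknown/missing tag: send the raw list
    return {"inputs": input}


print(build_hf_embedding_payload(["good morning from litellm", "this is another item"], "sentence-similarity"))
# -> {'inputs': {'source_sentence': 'good morning from litellm', 'sentences': ['this is another item']}}
```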


@@ -258,7 +258,7 @@ def get_ollama_response(
            logging_obj=logging_obj,
        )
        return response
-    elif stream == True:
+    elif stream is True:
        return ollama_completion_stream(url=url, data=data, logging_obj=logging_obj)

    response = requests.post(
@@ -326,7 +326,7 @@ def ollama_completion_stream(url, data, logging_obj):
        try:
            if response.status_code != 200:
                raise OllamaError(
-                    status_code=response.status_code, message=response.text
+                    status_code=response.status_code, message=response.read()
                )

            streamwrapper = litellm.CustomStreamWrapper(


@@ -3115,6 +3115,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
        or custom_llm_provider == "databricks"
        or custom_llm_provider == "watsonx"
        or custom_llm_provider == "cohere"
+        or custom_llm_provider == "huggingface"
    ):  # currently implemented aiohttp calls for just azure and openai, soon all.
        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
@@ -3454,15 +3455,18 @@ def embedding(
            or litellm.huggingface_key
            or get_secret("HUGGINGFACE_API_KEY")
            or litellm.api_key
-        )
+        )  # type: ignore
        response = huggingface.embedding(
            model=model,
            input=input,
-            encoding=encoding,
+            encoding=encoding,  # type: ignore
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging,
            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
        )
    elif custom_llm_provider == "bedrock":
        response = bedrock.embedding(


@@ -2944,7 +2944,7 @@ class Router:
            elif isinstance(id, int):
                id = str(id)

-            total_tokens = completion_response["usage"]["total_tokens"]
+            total_tokens = completion_response["usage"].get("total_tokens", 0)

            # ------------
            # Setup values


@@ -415,6 +415,62 @@ def test_hf_embedding():
# test_hf_embedding()

+from unittest.mock import MagicMock, patch
+
+
+def tgi_mock_post(*args, **kwargs):
+    import json
+
+    expected_data = {
+        "inputs": {
+            "source_sentence": "good morning from litellm",
+            "sentences": ["this is another item"],
+        }
+    }
+    assert (
+        json.loads(kwargs["data"]) == expected_data
+    ), "Data does not match the expected data"
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [0.7708950042724609]
+    return mock_response
+
+
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_hf_embedding_sentence_sim(sync_mode):
+    try:
+        # huggingface/microsoft/codebert-base
+        # huggingface/facebook/bart-large
+        if sync_mode is True:
+            client = HTTPHandler(concurrent_limit=1)
+        else:
+            client = AsyncHTTPHandler(concurrent_limit=1)
+        with patch.object(client, "post", side_effect=tgi_mock_post) as mock_client:
+            data = {
+                "model": "huggingface/TaylorAI/bge-micro-v2",
+                "input": ["good morning from litellm", "this is another item"],
+                "client": client,
+            }
+            if sync_mode is True:
+                response = embedding(**data)
+            else:
+                response = await litellm.aembedding(**data)
+
+            print(f"response:", response)
+
+            mock_client.assert_called_once()
+
+        assert isinstance(response.usage, litellm.Usage)
+
+    except Exception as e:
+        # Note: Huggingface inference API is unstable and fails with "model loading errors all the time"
+        raise e
+
+
# test async embeddings
def test_aembedding():
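The mocked test above exercises the new `client` parameter end to end. Outside of tests, the same parameter can be used to reuse one handler across several calls. A minimal sketch, assuming a valid `HUGGINGFACE_API_KEY` is set and that `AsyncHTTPHandler` takes the same `concurrent_limit` argument shown in the test:

```python
import asyncio

import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


async def main():
    # One shared async handler, passed into every aembedding call via `client`
    # (the same kwarg the parametrized test above forwards).
    client = AsyncHTTPHandler(concurrent_limit=1)
    batches = [
        ["good morning from litellm", "this is another item"],
        ["a second batch", "of sentences"],
    ]
    for batch in batches:
        response = await litellm.aembedding(
            model="huggingface/TaylorAI/bge-micro-v2",
            input=batch,
            client=client,
        )
        print(response.usage)


asyncio.run(main())
```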


@@ -571,6 +571,8 @@ async def test_completion_predibase_streaming(sync_mode):
        pass
    except litellm.InternalServerError as e:
        pass
+    except litellm.ServiceUnavailableError as e:
+        pass
    except Exception as e:
        print("ERROR class", e.__class__)
        print("ERROR message", e)
@@ -1764,6 +1766,7 @@ async def test_sagemaker_streaming_async():
# asyncio.run(test_sagemaker_streaming_async())


+@pytest.mark.skip(reason="costly sagemaker deployment. Move to mock implementation")
def test_completion_sagemaker_stream():
    try:
        response = completion(