Merge pull request #3571 from BerriAI/litellm_hf_classifier_support

Huggingface classifier support
Krish Dholakia 2024-05-10 17:54:27 -07:00 committed by GitHub
commit 1aa567f3b5
6 changed files with 415 additions and 64 deletions

View file

@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
<Tabs>
<TabItem value="tgi" label="Text-generation-interface (TGI)">
By default, LiteLLM will assume a huggingface call follows the TGI format.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@ -40,9 +45,58 @@ response = completion(
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: wizard-coder
    litellm_params:
      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "wizard-coder",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```
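
The proxy speaks the OpenAI chat-completions format, so the same request can also be sent with the `openai` Python client. A minimal sketch, reusing the key and port from the example above:

```python
import openai

# point the OpenAI client at the LiteLLM proxy started above
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="wizard-coder",
    messages=[{"role": "user", "content": "I like you!"}],
)
print(response)
```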
</TabItem>
</Tabs>
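
Streaming is also supported for TGI-format endpoints via the standard `stream=True` flag. A minimal sketch (the model name and endpoint URL are illustrative):

```python
import os
from litellm import completion

os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"

# stream tokens back from a TGI-format endpoint
response = completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    api_base="https://my-endpoint.huggingface.cloud",
    stream=True,
)
for chunk in response:
    print(chunk)
```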
</TabItem>
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
Append `conversational` to the model name
e.g. `huggingface/conversational/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
    model="huggingface/conversational/facebook/blenderbot-400M-distill",
    messages=messages,
    api_base="https://my-endpoint.huggingface.cloud"
)
@ -62,7 +116,123 @@ response = completion(
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: blenderbot
    litellm_params:
      model: huggingface/conversational/facebook/blenderbot-400M-distill
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "blenderbot",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
messages=messages,
api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
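
The classifier's raw label/score output is returned JSON-encoded in `message.content` (see the `text-classification` branch added to the response handler below), so it can be decoded with the standard library. A sketch continuing from the example above (the exact nesting mirrors whatever the endpoint returns):

```python
import json

scores = json.loads(response.choices[0].message.content)
print(scores)
# e.g. [[{"label": "LABEL_0", "score": 0.999}, {"label": "LABEL_1", "score": 0.0009}]]
```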
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: bert-classifier
    litellm_params:
      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "bert-classifier",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```
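
Since the proxy is OpenAI-compatible, the classifier can also be queried with the `openai` Python client and the scores decoded from the message content. A sketch, reusing the key and port from the example above:

```python
import json
import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="bert-classifier",
    messages=[{"role": "user", "content": "I like you!"}],
)
# the label/score pairs come back JSON-encoded in the message content
print(json.loads(response.choices[0].message.content))
```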
</TabItem>
</Tabs>
</TabItem>
<TabItem value="none" label="Text Generation (NOT TGI)">
Append `text-generation` to the model name
e.g. `huggingface/text-generation/<model-name>`
```python
import os
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion(
    model="huggingface/text-generation/roneneldan/TinyStories-3M",
    messages=messages,
    api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
)

View file

@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum
class HuggingfaceError(Exception):
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
        )  # Call the base class constructor with the parameters it needs

hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

hf_tasks = Literal[
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

class HuggingfaceConfig:
    """
    Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
    """

    hf_task: Optional[hf_tasks] = (
        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
    )
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: Optional[bool] = True  # enables returning logprobs + best of
@ -101,6 +121,51 @@ class HuggingfaceConfig:
            and v is not None
        }
    def get_supported_openai_params(self):
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "echo",
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                optional_params["do_sample"] = (
                    True  # Need to sample if you want best of for hf inference endpoints
                )
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            if param == "echo":
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
        return optional_params
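
As an illustration of the mapping above, a minimal sketch (assuming `HuggingfaceConfig` is exported on the top-level `litellm` module, as the `utils.py` change below relies on):

```python
import litellm

# OpenAI-style params translated to HF text-generation params
mapped = litellm.HuggingfaceConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2, "echo": True},
    optional_params={},
)
# temperature=0 is bumped to 0.01, max_tokens=0 becomes max_new_tokens=1,
# n maps to best_of (with do_sample=True), echo maps to decoder_input_details=True
print(mapped)
```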
def output_parser(generated_text: str):
    """
@ -162,16 +227,18 @@ def read_tgi_conv_models():
    return set(), set()
def get_hf_task_for_model(model: str) -> hf_tasks:
    # read text file, cast it to set
    # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
    if model.split("/")[0] in hf_task_list:
        return model.split("/")[0]  # type: ignore
    tgi_models, conversational_models = read_tgi_conv_models()
    if model in tgi_models:
        return "text-generation-inference"
    elif model in conversational_models:
        return "conversational"
    elif "roneneldan/TinyStories" in model:
        return "text-generation"
    else:
        return "text-generation-inference"  # default to tgi
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
        self,
        completion_response,
        model_response,
        task: hf_tasks,
        optional_params,
        encoding,
        input_text,
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
                        )
                        choices_list.append(choice_obj)
                    model_response["choices"].extend(choices_list)
        elif task == "text-classification":
            model_response["choices"][0]["message"]["content"] = json.dumps(
                completion_response
            )
        else:
            if len(completion_response[0]["generated_text"]) > 0:
                model_response["choices"][0]["message"]["content"] = output_parser(
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
        try:
            headers = self.validate_environment(api_key, headers)
            task = get_hf_task_for_model(model)
            ## VALIDATE API FORMAT
            if task is None or not isinstance(task, str) or task not in hf_task_list:
                raise Exception(
                    "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
                )
            print_verbose(f"{model}, {task}")
            completion_url = ""
            input_text = ""
@ -433,14 +510,15 @@ class Huggingface(BaseLLM):
                inference_params.pop("return_full_text")
                data = {
                    "inputs": prompt,
                }
                if task == "text-generation-inference":
                    data["parameters"] = inference_params
                    data["stream"] = (  # type: ignore
                        True
                        if "stream" in optional_params
                        and optional_params["stream"] == True
                        else False
                    )
                input_text = prompt
            ## LOGGING
            logging_obj.pre_call(
@ -531,10 +609,10 @@ class Huggingface(BaseLLM):
                isinstance(completion_response, dict)
                and "error" in completion_response
            ):
                print_verbose(f"completion error: {completion_response['error']}")  # type: ignore
                print_verbose(f"response.status_code: {response.status_code}")
                raise HuggingfaceError(
                    message=completion_response["error"],  # type: ignore
                    status_code=response.status_code,
                )
            return self.convert_to_model_response_object(
@ -563,7 +641,7 @@ class Huggingface(BaseLLM):
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        task: hf_tasks,
        encoding: Any,
        input_text: str,
        model: str,

View file

@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock
# litellm.num_retries=3
litellm.cache = None
@ -1137,7 +1138,7 @@ def test_get_hf_task_for_model():
model = "roneneldan/TinyStories-3M" model = "roneneldan/TinyStories-3M"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model) model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}") print(f"model:{model}, model type: {model_type}")
assert model_type == None assert model_type == "text-generation"
# test_get_hf_task_for_model() # test_get_hf_task_for_model()
@ -1145,15 +1146,92 @@ def test_get_hf_task_for_model():
# ################### Hugging Face TGI models ########################
# # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
def tgi_mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        {
            "generated_text": "<|assistant|>\nI'm",
            "details": {
                "finish_reason": "length",
                "generated_tokens": 10,
                "seed": None,
                "prefill": [],
                "tokens": [
                    {"id": 28789, "text": "<", "logprob": -0.025222778, "special": False},
                    {"id": 28766, "text": "|", "logprob": -0.000003695488, "special": False},
                    {"id": 489, "text": "ass", "logprob": -0.0000019073486, "special": False},
                    {"id": 11143, "text": "istant", "logprob": -0.000002026558, "special": False},
                    {"id": 28766, "text": "|", "logprob": -0.0000015497208, "special": False},
                    {"id": 28767, "text": ">", "logprob": -0.0000011920929, "special": False},
                    {"id": 13, "text": "\n", "logprob": -0.00009703636, "special": False},
                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
                    {"id": 28742, "text": "'", "logprob": -0.88183594, "special": False},
                    {"id": 28719, "text": "m", "logprob": -0.00032639503, "special": False},
                ],
            },
        }
    ]
    return mock_response
def test_hf_test_completion_tgi():
    litellm.set_verbose = True
    try:
        with patch("requests.post", side_effect=tgi_mock_post):
            response = completion(
                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                messages=[{"content": "Hello, how are you?", "role": "user"}],
                max_tokens=10,
            )
            # Add any assertions here to check the response
            print(response)
    except litellm.ServiceUnavailableError as e:
        pass
    except Exception as e:
@ -1191,6 +1269,40 @@ def hf_test_completion_tgi():
# except Exception as e:
#     pytest.fail(f"Error occurred: {e}")
# hf_test_completion_none_task()
def mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        [
            {"label": "LABEL_0", "score": 0.9990691542625427},
            {"label": "LABEL_1", "score": 0.0009308889275416732},
        ]
    ]
    return mock_response


def test_hf_classifier_task():
    try:
        with patch("requests.post", side_effect=mock_post):
            litellm.set_verbose = True
            user_message = "I like you. I love you"
            messages = [{"content": user_message, "role": "user"}]
            response = completion(
                model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
                messages=messages,
            )
            print(f"response: {response}")
            assert isinstance(response, litellm.ModelResponse)
            assert isinstance(response.choices[0], litellm.Choices)
            assert response.choices[0].message.content is not None
            assert isinstance(response.choices[0].message.content, str)
    except Exception as e:
        pytest.fail(f"Error occurred: {str(e)}")
########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api():
#     # failing on circle ci commenting out

View file

@ -3,7 +3,27 @@ from litellm import get_optional_params
litellm.add_function_to_prompt = True
optional_params = get_optional_params(
    model="",
    tools=[
        {
            "type": "function",
            "function": {
                "description": "Get the current weather in a given location",
                "name": "get_current_weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
    ],
    tool_choice="auto",
)
assert optional_params is not None

View file

@ -86,6 +86,7 @@ def test_azure_optional_params_embeddings():
def test_azure_gpt_optional_params_gpt_vision():
    # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
    optional_params = litellm.utils.get_optional_params(
        model="",
        user="John",
        custom_llm_provider="azure",
        max_tokens=10,
@ -125,6 +126,7 @@ def test_azure_gpt_optional_params_gpt_vision():
def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
    # if user passes extra_body, we should not over write it, we should pass it along to OpenAI python
    optional_params = litellm.utils.get_optional_params(
        model="",
        user="John",
        custom_llm_provider="azure",
        max_tokens=10,
@ -167,6 +169,7 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
def test_openai_extra_headers():
    optional_params = litellm.utils.get_optional_params(
        model="",
        user="John",
        custom_llm_provider="openai",
        max_tokens=10,

View file

@ -4871,6 +4871,7 @@ def get_optional_params_embeddings(
def get_optional_params(
    # use the openai defaults
    # https://platform.openai.com/docs/api-reference/chat/create
    model: str,
    functions=None,
    function_call=None,
    temperature=None,
@ -4884,7 +4885,6 @@ def get_optional_params(
    frequency_penalty=None,
    logit_bias=None,
    user=None,
    custom_llm_provider="",
    response_format=None,
    seed=None,
@ -4913,7 +4913,7 @@ def get_optional_params(
        passed_params[k] = v
    optional_params: Dict = {}
    common_auth_dict = litellm.common_cloud_provider_auth_params
    if custom_llm_provider in common_auth_dict["providers"]:
@ -5187,41 +5187,9 @@ def get_optional_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.HuggingfaceConfig().map_openai_params(
            non_default_params=non_default_params, optional_params=optional_params
        )
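
A minimal sketch of the new call shape, with `model` now a required parameter and the Hugging Face translation delegated to the config class (values illustrative):

```python
from litellm import get_optional_params

params = get_optional_params(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    custom_llm_provider="huggingface",
    temperature=0,
    max_tokens=0,
)
# per HuggingfaceConfig.map_openai_params, this is expected to include
# {"temperature": 0.01, "max_new_tokens": 1}
print(params)
```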
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
## check if unsupported param passed in ## check if unsupported param passed in
supported_params = get_supported_openai_params( supported_params = get_supported_openai_params(
@ -6181,7 +6149,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"seed", "seed",
] ]
elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] return litellm.HuggingfaceConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai": elif custom_llm_provider == "together_ai":
return [ return [
"stream", "stream",