Merge pull request #3571 from BerriAI/litellm_hf_classifier_support

Huggingface classifier support
Krish Dholakia 2024-05-10 17:54:27 -07:00 committed by GitHub
commit 1aa567f3b5
6 changed files with 415 additions and 64 deletions

View file

@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
<Tabs>
<TabItem value="tgi" label="Text-generation-interface (TGI)">
By default, LiteLLM will assume a huggingface call follows the TGI format.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@ -40,9 +45,58 @@ response = completion(
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: wizard-coder
    litellm_params:
      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "wizard-coder",
    "messages": [
        {
            "role": "user",
            "content": "I like you!"
        }
    ]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
Append `conversational` to the model name
e.g. `huggingface/conversational/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
    model="huggingface/facebook/blenderbot-400M-distill",
    model="huggingface/conversational/facebook/blenderbot-400M-distill",
    messages=messages,
    api_base="https://my-endpoint.huggingface.cloud"
)
@ -62,7 +116,123 @@ response = completion(
print(response)
```
</TabItem>
<TabItem value="none" label="Non TGI/Conversational-task LLMs">
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: blenderbot
    litellm_params:
      model: huggingface/conversational/facebook/blenderbot-400M-distill
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "blenderbot",
    "messages": [
        {
            "role": "user",
            "content": "I like you!"
        }
    ]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    messages=messages,
    api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
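
The classifier's raw output is returned as a JSON string in `message.content` (label names and scores depend entirely on the model you deploy). A minimal sketch of reading it back, assuming a binary classifier like the one above:

```python
import json

# hypothetical follow-up to the call above; label names are model-specific
scores = json.loads(response.choices[0].message.content)
print(scores)  # e.g. [[{"label": "LABEL_0", "score": 0.999}, {"label": "LABEL_1", "score": 0.0009}]]
```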
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: bert-classifier
    litellm_params:
      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "bert-classifier",
    "messages": [
        {
            "role": "user",
            "content": "I like you!"
        }
    ]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="none" label="Text Generation (NOT TGI)">
Append `text-generation` to the model name
e.g. `huggingface/text-generation/<model-name>`
```python
import os
@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion(
    model="huggingface/roneneldan/TinyStories-3M",
    model="huggingface/text-generation/roneneldan/TinyStories-3M",
    messages=messages,
    api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
)

View file

@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum
class HuggingfaceError(Exception):
@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
) # Call the base class constructor with the parameters it needs
hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]
hf_tasks = Literal[
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]
class HuggingfaceConfig:
    """
    Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
    """

    hf_task: Optional[hf_tasks] = (
        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
    )
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: Optional[bool] = True  # enables returning logprobs + best of
@ -101,6 +121,51 @@ class HuggingfaceConfig:
and v is not None
}
    def get_supported_openai_params(self):
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "echo",
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                optional_params["do_sample"] = (
                    True  # Need to sample if you want best of for hf inference endpoints
                )
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            if param == "echo":
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
        return optional_params
def output_parser(generated_text: str):
"""
@ -162,16 +227,18 @@ def read_tgi_conv_models():
return set(), set()
def get_hf_task_for_model(model):
def get_hf_task_for_model(model: str) -> hf_tasks:
    # read text file, cast it to set
    # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
    if model.split("/")[0] in hf_task_list:
        return model.split("/")[0]  # type: ignore
    tgi_models, conversational_models = read_tgi_conv_models()
    if model in tgi_models:
        return "text-generation-inference"
    elif model in conversational_models:
        return "conversational"
    elif "roneneldan/TinyStories" in model:
        return None
        return "text-generation"
    else:
        return "text-generation-inference"  # default to tgi
@ -202,7 +269,7 @@ class Huggingface(BaseLLM):
self,
completion_response,
model_response,
task,
task: hf_tasks,
optional_params,
encoding,
input_text,
@ -270,6 +337,10 @@ class Huggingface(BaseLLM):
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
elif task == "text-classification":
model_response["choices"][0]["message"]["content"] = json.dumps(
completion_response
)
else:
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
@ -333,6 +404,12 @@ class Huggingface(BaseLLM):
try:
headers = self.validate_environment(api_key, headers)
task = get_hf_task_for_model(model)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
print_verbose(f"{model}, {task}")
completion_url = ""
input_text = ""
@ -433,14 +510,15 @@ class Huggingface(BaseLLM):
inference_params.pop("return_full_text")
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": ( # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True
if "stream" in optional_params
and optional_params["stream"] == True
else False
),
}
)
input_text = prompt
## LOGGING
logging_obj.pre_call(
@ -531,10 +609,10 @@ class Huggingface(BaseLLM):
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError(
message=completion_response["error"],
message=completion_response["error"], # type: ignore
status_code=response.status_code,
)
return self.convert_to_model_response_object(
@ -563,7 +641,7 @@ class Huggingface(BaseLLM):
data: dict,
headers: dict,
model_response: ModelResponse,
task: str,
task: hf_tasks,
encoding: Any,
input_text: str,
model: str,

View file

@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock
# litellm.num_retries=3
litellm.cache = None
@ -1137,7 +1138,7 @@ def test_get_hf_task_for_model():
model = "roneneldan/TinyStories-3M"
model_type = litellm.llms.huggingface_restapi.get_hf_task_for_model(model)
print(f"model:{model}, model type: {model_type}")
assert model_type == None
assert model_type == "text-generation"
# test_get_hf_task_for_model()
@ -1145,15 +1146,92 @@ def test_get_hf_task_for_model():
# ################### Hugging Face TGI models ########################
# # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
def hf_test_completion_tgi():
# litellm.set_verbose=True
def tgi_mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        {
            "generated_text": "<|assistant|>\nI'm",
            "details": {
                "finish_reason": "length",
                "generated_tokens": 10,
                "seed": None,
                "prefill": [],
                "tokens": [
                    {
                        "id": 28789,
                        "text": "<",
                        "logprob": -0.025222778,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.000003695488,
                        "special": False,
                    },
                    {
                        "id": 489,
                        "text": "ass",
                        "logprob": -0.0000019073486,
                        "special": False,
                    },
                    {
                        "id": 11143,
                        "text": "istant",
                        "logprob": -0.000002026558,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.0000015497208,
                        "special": False,
                    },
                    {
                        "id": 28767,
                        "text": ">",
                        "logprob": -0.0000011920929,
                        "special": False,
                    },
                    {
                        "id": 13,
                        "text": "\n",
                        "logprob": -0.00009703636,
                        "special": False,
                    },
                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
                    {
                        "id": 28742,
                        "text": "'",
                        "logprob": -0.88183594,
                        "special": False,
                    },
                    {
                        "id": 28719,
                        "text": "m",
                        "logprob": -0.00032639503,
                        "special": False,
                    },
                ],
            },
        }
    ]
    return mock_response
def test_hf_test_completion_tgi():
litellm.set_verbose = True
try:
response = completion(
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
messages=[{"content": "Hello, how are you?", "role": "user"}],
)
# Add any assertions here to check the response
print(response)
with patch("requests.post", side_effect=tgi_mock_post):
response = completion(
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
messages=[{"content": "Hello, how are you?", "role": "user"}],
max_tokens=10,
)
# Add any assertions here to check the response
print(response)
except litellm.ServiceUnavailableError as e:
pass
except Exception as e:
@ -1191,6 +1269,40 @@ def hf_test_completion_tgi():
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# hf_test_completion_none_task()
def mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        [
            {"label": "LABEL_0", "score": 0.9990691542625427},
            {"label": "LABEL_1", "score": 0.0009308889275416732},
        ]
    ]
    return mock_response
def test_hf_classifier_task():
    try:
        with patch("requests.post", side_effect=mock_post):
            litellm.set_verbose = True
            user_message = "I like you. I love you"
            messages = [{"content": user_message, "role": "user"}]
            response = completion(
                model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
                messages=messages,
            )
            print(f"response: {response}")
            assert isinstance(response, litellm.ModelResponse)
            assert isinstance(response.choices[0], litellm.Choices)
            assert response.choices[0].message.content is not None
            assert isinstance(response.choices[0].message.content, str)
    except Exception as e:
        pytest.fail(f"Error occurred: {str(e)}")
########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api():
# # failing on circle ci commenting out

View file

@ -3,7 +3,27 @@ from litellm import get_optional_params
litellm.add_function_to_prompt = True
optional_params = get_optional_params(
tools= [{'type': 'function', 'function': {'description': 'Get the current weather in a given location', 'name': 'get_current_weather', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}],
tool_choice= 'auto',
model="",
tools=[
{
"type": "function",
"function": {
"description": "Get the current weather in a given location",
"name": "get_current_weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
],
tool_choice="auto",
)
assert optional_params is not None
assert optional_params is not None

View file

@ -86,6 +86,7 @@ def test_azure_optional_params_embeddings():
def test_azure_gpt_optional_params_gpt_vision():
# for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
optional_params = litellm.utils.get_optional_params(
model="",
user="John",
custom_llm_provider="azure",
max_tokens=10,
@ -125,6 +126,7 @@ def test_azure_gpt_optional_params_gpt_vision():
def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
# if user passes extra_body, we should not over write it, we should pass it along to OpenAI python
optional_params = litellm.utils.get_optional_params(
model="",
user="John",
custom_llm_provider="azure",
max_tokens=10,
@ -167,6 +169,7 @@ def test_azure_gpt_optional_params_gpt_vision_with_extra_body():
def test_openai_extra_headers():
optional_params = litellm.utils.get_optional_params(
model="",
user="John",
custom_llm_provider="openai",
max_tokens=10,

View file

@ -4871,6 +4871,7 @@ def get_optional_params_embeddings(
def get_optional_params(
# use the openai defaults
# https://platform.openai.com/docs/api-reference/chat/create
model: str,
functions=None,
function_call=None,
temperature=None,
@ -4884,7 +4885,6 @@ def get_optional_params(
frequency_penalty=None,
logit_bias=None,
user=None,
model=None,
custom_llm_provider="",
response_format=None,
seed=None,
@ -4913,7 +4913,7 @@ def get_optional_params(
passed_params[k] = v
optional_params = {}
optional_params: Dict = {}
common_auth_dict = litellm.common_cloud_provider_auth_params
if custom_llm_provider in common_auth_dict["providers"]:
@ -5187,41 +5187,9 @@ def get_optional_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if temperature is not None:
if temperature == 0.0 or temperature == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
temperature = 0.01
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if n is not None:
optional_params["best_of"] = n
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if stream is not None:
optional_params["stream"] = stream
if stop is not None:
optional_params["stop"] = stop
if max_tokens is not None:
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if max_tokens == 0:
max_tokens = 1
optional_params["max_new_tokens"] = max_tokens
if n is not None:
optional_params["best_of"] = n
if presence_penalty is not None:
optional_params["repetition_penalty"] = presence_penalty
if "echo" in passed_params:
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = special_params["echo"]
passed_params.pop(
"echo", None
) # since we handle translating echo, we should not send it to TGI request
optional_params = litellm.HuggingfaceConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params
)
elif custom_llm_provider == "together_ai":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
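
Net effect of this hunk: the Hugging Face branch of `get_optional_params` now delegates parameter mapping to `HuggingfaceConfig`, and `model` becomes a required argument. A hedged sketch of the new call shape (parameter values are illustrative):

```python
import litellm

# `model` must now be passed explicitly; HF params are mapped via HuggingfaceConfig
params = litellm.utils.get_optional_params(
    model="HuggingFaceH4/zephyr-7b-beta",
    custom_llm_provider="huggingface",
    temperature=0,
    max_tokens=10,
)
print(params)  # roughly {'temperature': 0.01, 'max_new_tokens': 10}
```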
@ -6181,7 +6149,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"seed",
]
elif custom_llm_provider == "huggingface":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
return litellm.HuggingfaceConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai":
return [
"stream",