docs(huggingface.md): add text-classification to huggingface docs

This commit is contained in:
Krrish Dholakia 2024-05-10 14:39:14 -07:00
parent 50be25d11a
commit d4d175030f
3 changed files with 177 additions and 9 deletions

View file

@@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
<Tabs>
<TabItem value="tgi" label="Text-generation-inference (TGI)">
By default, LiteLLM will assume a Hugging Face call follows the TGI format.
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@@ -40,9 +45,58 @@ response = completion(
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: wizard-coder
    litellm_params:
      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "wizard-coder",
    "messages": [
        {
            "role": "user",
            "content": "I like you!"
        }
    ]
}'
```
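
You can also hit the proxy with the OpenAI Python SDK instead of curl. A minimal sketch, assuming the proxy from step 2 is running locally on port 4000 with the config above:

```python
import openai

client = openai.OpenAI(
    api_key="sk-1234",               # your litellm proxy api key
    base_url="http://0.0.0.0:4000"   # your litellm proxy base url
)

response = client.chat.completions.create(
    model="wizard-coder",  # the model_name from config.yaml
    messages=[{"role": "user", "content": "I like you!"}],
)
print(response)
```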
</TabItem>
</Tabs>
</TabItem>
<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
Append `conversational` to the model name
e.g. `huggingface/conversational/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
@@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
    model="huggingface/conversational/facebook/blenderbot-400M-distill",
    messages=messages,
    api_base="https://my-endpoint.huggingface.cloud"
)
@@ -62,7 +116,123 @@ response = completion(
print(response)
```
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: blenderbot
    litellm_params:
      model: huggingface/conversational/facebook/blenderbot-400M-distill
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "blenderbot",
    "messages": [
        {
            "role": "user",
            "content": "I like you!"
        }
    ]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="classification" label="Text Classification">
Append `text-classification` to the model name
e.g. `huggingface/text-classification/<model-name>`
<Tabs>
<TabItem value="sdk" label="SDK">
```python
import os
from litellm import completion
# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
messages = [{ "content": "I like you, I love you!","role": "user"}]
# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    messages=messages,
    api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```
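
Under the hood, LiteLLM routes this to the Hugging Face Inference `text-classification` task on your endpoint. For context, a rough sketch of the raw request it wraps, using `requests`; the endpoint URL is a placeholder and the exact response shape depends on the deployed model:

```python
import os
import requests

# placeholder endpoint - replace with your own HF Inference Endpoint URL
API_URL = "https://my-endpoint.endpoints.huggingface.cloud"
headers = {"Authorization": f"Bearer {os.environ['HUGGINGFACE_API_KEY']}"}

# text-classification endpoints take a raw string under "inputs"
resp = requests.post(API_URL, headers=headers, json={"inputs": "I like you, I love you!"})

# typically a list of {label, score} pairs, e.g. [[{"label": "LABEL_1", "score": 0.98}, ...]]
print(resp.json())
```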
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Add models to your config.yaml
```yaml
model_list:
  - model_name: bert-classifier
    litellm_params:
      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml --debug
```
3. Test it!
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "bert-classifier",
    "messages": [
        {
            "role": "user",
            "content": "I like you!"
        }
    ]
}'
```
</TabItem>
</Tabs>
</TabItem>
<TabItem value="none" label="Text Generation (NOT TGI)">
Append `text-generation` to the model name
e.g. `huggingface/text-generation/<model-name>`
```python
import os
@@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion(
    model="huggingface/text-generation/roneneldan/TinyStories-3M",
    messages=messages,
    api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
)

View file

@@ -230,6 +230,8 @@ def read_tgi_conv_models():
 def get_hf_task_for_model(model: str) -> hf_tasks:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
+    if model.split("/")[0] in hf_task_list:
+        return model.split("/")[0]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
         return "text-generation-inference"
@@ -401,10 +403,7 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            if optional_params.get("hf_task") is None:
-                task = get_hf_task_for_model(model)
-            else:
-                task = optional_params.get("hf_task")  # type: ignore
+            task = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
                 raise Exception(
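
Taken together, the two hunks above move task selection into the model string itself: the first path segment of the model name (after the `huggingface/` provider prefix is removed) is checked against the supported task list, so the separate `hf_task` parameter is no longer needed. A rough sketch of that lookup; the contents of `hf_task_list` here are illustrative, not the exact definition in the LiteLLM source:

```python
from typing import Optional

# illustrative task list - the real hf_task_list is defined in the LiteLLM source
hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

def task_from_model(model: str) -> Optional[str]:
    # "text-classification/shahrukhx01/question-vs-statement-classifier" -> "text-classification"
    prefix = model.split("/")[0]
    return prefix if prefix in hf_task_list else None

print(task_from_model("text-classification/shahrukhx01/question-vs-statement-classifier"))
# -> text-classification
print(task_from_model("WizardLM/WizardCoder-Python-34B-V1.0"))
# -> None (falls back to the existing TGI / conversational model lists)
```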

View file

@@ -1291,9 +1291,8 @@ def test_hf_classifier_task():
     user_message = "I like you. I love you"
     messages = [{"content": user_message, "role": "user"}]
     response = completion(
-        model="huggingface/shahrukhx01/question-vs-statement-classifier",
+        model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
         messages=messages,
-        hf_task="text-classification",
     )
     print(f"response: {response}")
     assert isinstance(response, litellm.ModelResponse)