forked from phoenix/litellm-mirror
docs(huggingface.md): add text-classification to huggingface docs
parent 50be25d11a, commit d4d175030f
3 changed files with 177 additions and 9 deletions
@@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion(
 <Tabs>
 <TabItem value="tgi" label="Text-generation-interface (TGI)">
 
+By default, LiteLLM will assume a huggingface call follows the TGI format.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion
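With no task prefix on the model id, LiteLLM treats the deployment as a TGI endpoint. A rough, hypothetical sketch of such a default call follows; the WizardCoder model id and endpoint URL are borrowed from the proxy config further down and stand in for your own deployment, and the API key value is a placeholder.

```python
# Hypothetical sketch of a default (TGI-format) call: no task prefix on the model.
import os
from litellm import completion

# placeholder key; any HF Inference Endpoints token works here
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"

messages = [{"content": "There's a llama in my garden 😱 What should I do?", "role": "user"}]

# no "conversational/" / "text-classification/" / "text-generation/" prefix,
# so LiteLLM assumes the endpoint speaks the TGI format
response = completion(
    model="huggingface/WizardLM/WizardCoder-Python-34B-V1.0",
    messages=messages,
    api_base="https://my-endpoint.endpoints.huggingface.cloud",
)
print(response)
```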
@@ -40,9 +45,58 @@
 print(response)
 ```
 
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: wizard-coder
+    litellm_params:
+      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "wizard-coder",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ]
+}'
+```
+
+</TabItem>
+</Tabs>
 </TabItem>
 <TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">
 
+Append `conversational` to the model name
+
+e.g. `huggingface/conversational/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
 ```python
 import os
 from litellm import completion
@@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
 
 # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/facebook/blenderbot-400M-distill",
+    model="huggingface/conversational/facebook/blenderbot-400M-distill",
     messages=messages,
     api_base="https://my-endpoint.huggingface.cloud"
 )
@@ -62,7 +116,123 @@ response = completion(
 print(response)
 ```
 </TabItem>
-<TabItem value="none" label="Non TGI/Conversational-task LLMs">
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: blenderbot
+    litellm_params:
+      model: huggingface/conversational/facebook/blenderbot-400M-distill
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "blenderbot",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ]
+}'
+```
+
+</TabItem>
+</Tabs>
+</TabItem>
+<TabItem value="classification" label="Text Classification">
+
+Append `text-classification` to the model name
+
+e.g. `huggingface/text-classification/<model-name>`
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+import os
+from litellm import completion
+
+# [OPTIONAL] set env var
+os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
+
+messages = [{ "content": "I like you, I love you!","role": "user"}]
+
+# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
+response = completion(
+    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
+    messages=messages,
+    api_base="https://my-endpoint.endpoints.huggingface.cloud",
+)
+
+print(response)
+```
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+1. Add models to your config.yaml
+
+```yaml
+model_list:
+  - model_name: bert-classifier
+    litellm_params:
+      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
+      api_key: os.environ/HUGGINGFACE_API_KEY
+      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+```
+
+2. Start the proxy
+
+```bash
+$ litellm --config /path/to/config.yaml --debug
+```
+
+3. Test it!
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "bert-classifier",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ]
+}'
+```
+
+</TabItem>
+</Tabs>
+</TabItem>
+<TabItem value="none" label="Text Generation (NOT TGI)">
 
+Append `text-generation` to the model name
+
+e.g. `huggingface/text-generation/<model-name>`
+
 ```python
 import os
@@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
 
 # e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/roneneldan/TinyStories-3M",
+    model="huggingface/text-generation/roneneldan/TinyStories-3M",
     messages=messages,
     api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
 )
@@ -230,6 +230,8 @@ def read_tgi_conv_models():
 def get_hf_task_for_model(model: str) -> hf_tasks:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
+    if model.split("/")[0] in hf_task_list:
+        return model.split("/")[0]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
         return "text-generation-inference"
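The two added lines are what make the new `huggingface/<task>/<model>` prefixes from the docs work: if the first segment of the model id (the `huggingface/` provider prefix has already been stripped at this point) names a known task, that task is returned directly, and the TGI/conversational lookups are only reached for unprefixed models. A minimal, self-contained sketch of that routing; the `hf_task_list` contents and the final fallback shown here are assumptions based on the task names used in this commit, and the file-backed `read_tgi_conv_models()` lookup is left out.

```python
# Standalone sketch of the prefix-based task routing; hf_task_list here is an
# assumption built from the task names used in this commit, and the
# read_tgi_conv_models() file lookup from the real implementation is omitted.
hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

def get_hf_task_for_model(model: str) -> str:
    # the "huggingface/" provider prefix has already been stripped,
    # e.g. "text-classification/shahrukhx01/question-vs-statement-classifier"
    if model.split("/")[0] in hf_task_list:
        return model.split("/")[0]
    # unprefixed models fall back to the default TGI assumption
    return "text-generation-inference"

print(get_hf_task_for_model("text-classification/shahrukhx01/question-vs-statement-classifier"))
# text-classification
print(get_hf_task_for_model("WizardLM/WizardCoder-Python-34B-V1.0"))
# text-generation-inference
```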
@@ -401,10 +403,7 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            if optional_params.get("hf_task") is None:
-                task = get_hf_task_for_model(model)
-            else:
-                task = optional_params.get("hf_task")  # type: ignore
+            task = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
                 raise Exception(
@@ -1291,9 +1291,8 @@ def test_hf_classifier_task():
     user_message = "I like you. I love you"
     messages = [{"content": user_message, "role": "user"}]
     response = completion(
-        model="huggingface/shahrukhx01/question-vs-statement-classifier",
+        model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
         messages=messages,
-        hf_task="text-classification",
     )
     print(f"response: {response}")
     assert isinstance(response, litellm.ModelResponse)