docs(huggingface.md): add text-classification to huggingface docs
parent 50be25d11a
commit d4d175030f

3 changed files with 177 additions and 9 deletions
@@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion

<Tabs>
<TabItem value="tgi" label="Text-generation-inference (TGI)">

By default, LiteLLM will assume a huggingface call follows the TGI format.

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os
from litellm import completion

@@ -40,9 +45,58 @@ response = completion(

print(response)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: wizard-coder
    litellm_params:
      model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --debug
```

3. Test it!

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "wizard-coder",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```

</TabItem>
</Tabs>
</TabItem>

<TabItem value="conv" label="Conversational-task (BlenderBot, etc.)">

Append `conversational` to the model name

e.g. `huggingface/conversational/<model-name>`

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os
from litellm import completion

@@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

# e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
response = completion(
-    model="huggingface/facebook/blenderbot-400M-distill",
    model="huggingface/conversational/facebook/blenderbot-400M-distill",
    messages=messages,
    api_base="https://my-endpoint.huggingface.cloud"
)

@@ -62,7 +116,123 @@ response = completion(

print(response)
```

</TabItem>
-<TabItem value="none" label="Non TGI/Conversational-task LLMs">
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: blenderbot
    litellm_params:
      model: huggingface/conversational/facebook/blenderbot-400M-distill
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --debug
```

3. Test it!

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "blenderbot",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```

</TabItem>
</Tabs>
</TabItem>

<TabItem value="classification" label="Text Classification">

Append `text-classification` to the model name

e.g. `huggingface/text-classification/<model-name>`

<Tabs>
<TabItem value="sdk" label="SDK">

```python
import os
from litellm import completion

# [OPTIONAL] set env var
os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"

messages = [{ "content": "I like you, I love you!","role": "user"}]

# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
response = completion(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    messages=messages,
    api_base="https://my-endpoint.endpoints.huggingface.cloud",
)

print(response)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Add models to your config.yaml

```yaml
model_list:
  - model_name: bert-classifier
    litellm_params:
      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://my-endpoint.endpoints.huggingface.cloud"
```

2. Start the proxy

```bash
$ litellm --config /path/to/config.yaml --debug
```

3. Test it!

```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "bert-classifier",
    "messages": [
      {
        "role": "user",
        "content": "I like you!"
      }
    ]
}'
```

</TabItem>
</Tabs>
</TabItem>

<TabItem value="none" label="Text Generation (NOT TGI)">

Append `text-generation` to the model name

e.g. `huggingface/text-generation/<model-name>`

```python
import os

@@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","

# e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
response = completion(
-    model="huggingface/roneneldan/TinyStories-3M",
    model="huggingface/text-generation/roneneldan/TinyStories-3M",
    messages=messages,
    api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
)

@@ -230,6 +230,8 @@ def read_tgi_conv_models():

def get_hf_task_for_model(model: str) -> hf_tasks:
    # read text file, cast it to set
    # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
    if model.split("/")[0] in hf_task_list:
        return model.split("/")[0]  # type: ignore
    tgi_models, conversational_models = read_tgi_conv_models()
    if model in tgi_models:
        return "text-generation-inference"

@@ -401,10 +403,7 @@ class Huggingface(BaseLLM):

        exception_mapping_worked = False
        try:
            headers = self.validate_environment(api_key, headers)
            if optional_params.get("hf_task") is None:
                task = get_hf_task_for_model(model)
            else:
                task = optional_params.get("hf_task")  # type: ignore
            ## VALIDATE API FORMAT
            if task is None or not isinstance(task, str) or task not in hf_task_list:
                raise Exception(
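
As the snippet shows, the task comes either from an explicit `hf_task` parameter or, failing that, from the model-name prefix. A minimal sketch of the two call styles, assuming both paths are live in this version (the endpoint URL is a placeholder):

```python
from litellm import completion

messages = [{"role": "user", "content": "I like you, I love you!"}]

# Style 1: task encoded in the model-name prefix, resolved by get_hf_task_for_model
response = completion(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    messages=messages,
    api_base="https://my-endpoint.endpoints.huggingface.cloud",  # placeholder endpoint
)

# Style 2: task passed explicitly via hf_task, which skips the prefix lookup
response = completion(
    model="huggingface/shahrukhx01/question-vs-statement-classifier",
    messages=messages,
    hf_task="text-classification",
    api_base="https://my-endpoint.endpoints.huggingface.cloud",  # placeholder endpoint
)
```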

@@ -1291,9 +1291,8 @@ def test_hf_classifier_task():

    user_message = "I like you. I love you"
    messages = [{"content": user_message, "role": "user"}]
    response = completion(
-        model="huggingface/shahrukhx01/question-vs-statement-classifier",
        model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
        messages=messages,
-        hf_task="text-classification",
    )
    print(f"response: {response}")
    assert isinstance(response, litellm.ModelResponse)
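
To exercise just this test locally, something like the following should work (the test-file path is an assumption about the repo layout):

```python
# Hypothetical runner: the file path is an assumption, adjust to your checkout.
import pytest

pytest.main(["litellm/tests/test_completion.py", "-k", "test_hf_classifier_task", "-s"])
```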