diff --git a/docs/my-website/docs/providers/huggingface.md b/docs/my-website/docs/providers/huggingface.md
index f8ebadfcf..35befd3e2 100644
--- a/docs/my-website/docs/providers/huggingface.md
+++ b/docs/my-website/docs/providers/huggingface.md
@@ -21,6 +21,11 @@ This is done by adding the "huggingface/" prefix to `model`, example `completion
+By default, LiteLLM will assume a huggingface call follows the TGI format.
+
+
+
+
 ```python
 import os
 from litellm import completion
@@ -40,9 +45,58 @@ response = completion(
 print(response)
 ```
+
+
+
+1. Add models to your config.yaml
+
+  ```yaml
+  model_list:
+    - model_name: wizard-coder
+      litellm_params:
+        model: huggingface/WizardLM/WizardCoder-Python-34B-V1.0
+        api_key: os.environ/HUGGINGFACE_API_KEY
+        api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+  ```
+
+
+
+2. Start the proxy
+
+  ```bash
+  $ litellm --config /path/to/config.yaml --debug
+  ```
+
+3. Test it!
+
+  ```shell
+  curl --location 'http://0.0.0.0:4000/chat/completions' \
+  --header 'Authorization: Bearer sk-1234' \
+  --header 'Content-Type: application/json' \
+  --data '{
+    "model": "wizard-coder",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ]
+  }'
+  ```
+
+
+
+Append `conversational` to the model name
+
+e.g. `huggingface/conversational/`
+
+
+
+
 ```python
 import os
 from litellm import completion
@@ -54,7 +108,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
 
 # e.g. Call 'facebook/blenderbot-400M-distill' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/facebook/blenderbot-400M-distill",
+    model="huggingface/conversational/facebook/blenderbot-400M-distill",
     messages=messages,
     api_base="https://my-endpoint.huggingface.cloud"
 )
@@ -62,7 +116,123 @@ response = completion(
 print(response)
 ```
-
+
+
+1. Add models to your config.yaml
+
+  ```yaml
+  model_list:
+    - model_name: blenderbot
+      litellm_params:
+        model: huggingface/conversational/facebook/blenderbot-400M-distill
+        api_key: os.environ/HUGGINGFACE_API_KEY
+        api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+  ```
+
+
+
+2. Start the proxy
+
+  ```bash
+  $ litellm --config /path/to/config.yaml --debug
+  ```
+
+3. Test it!
+
+  ```shell
+  curl --location 'http://0.0.0.0:4000/chat/completions' \
+  --header 'Authorization: Bearer sk-1234' \
+  --header 'Content-Type: application/json' \
+  --data '{
+    "model": "blenderbot",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ]
+  }'
+  ```
+
+
+
+
+
+
+Append `text-classification` to the model name
+
+e.g. `huggingface/text-classification/`
+
+
+
+
+```python
+import os
+from litellm import completion
+
+# [OPTIONAL] set env var
+os.environ["HUGGINGFACE_API_KEY"] = "huggingface_api_key"
+
+messages = [{ "content": "I like you, I love you!","role": "user"}]
+
+# e.g. Call 'shahrukhx01/question-vs-statement-classifier' hosted on HF Inference endpoints
+response = completion(
+    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
+    messages=messages,
+    api_base="https://my-endpoint.endpoints.huggingface.cloud",
+)
+
+print(response)
+```
+
+
+
+1. Add models to your config.yaml
+
+  ```yaml
+  model_list:
+    - model_name: bert-classifier
+      litellm_params:
+        model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
+        api_key: os.environ/HUGGINGFACE_API_KEY
+        api_base: "https://my-endpoint.endpoints.huggingface.cloud"
+  ```
+
+
+
+2. Start the proxy
+
+  ```bash
+  $ litellm --config /path/to/config.yaml --debug
+  ```
+
+3. Test it!
+
+  ```shell
+  curl --location 'http://0.0.0.0:4000/chat/completions' \
+  --header 'Authorization: Bearer sk-1234' \
+  --header 'Content-Type: application/json' \
+  --data '{
+    "model": "bert-classifier",
+    "messages": [
+      {
+        "role": "user",
+        "content": "I like you!"
+      }
+    ]
+  }'
+  ```
+
+
+
+
+
+
+Append `text-generation` to the model name
+
+e.g. `huggingface/text-generation/`
 ```python
 import os
@@ -75,7 +245,7 @@ messages = [{ "content": "There's a llama in my garden 😱 What should I do?","
 
 # e.g. Call 'roneneldan/TinyStories-3M' hosted on HF Inference endpoints
 response = completion(
-    model="huggingface/roneneldan/TinyStories-3M",
+    model="huggingface/text-generation/roneneldan/TinyStories-3M",
     messages=messages,
     api_base="https://p69xlsj6rpno5drq.us-east-1.aws.endpoints.huggingface.cloud",
 )
diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 26591b95d..ad3c570e7 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -230,6 +230,8 @@ def read_tgi_conv_models():
 def get_hf_task_for_model(model: str) -> hf_tasks:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
+    if model.split("/")[0] in hf_task_list:
+        return model.split("/")[0]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
         return "text-generation-inference"
@@ -401,10 +403,7 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            if optional_params.get("hf_task") is None:
-                task = get_hf_task_for_model(model)
-            else:
-                task = optional_params.get("hf_task")  # type: ignore
+            task = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
                 raise Exception(
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 4da489cc5..b3e1928a8 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1291,9 +1291,8 @@ def test_hf_classifier_task():
     user_message = "I like you. I love you"
     messages = [{"content": user_message, "role": "user"}]
     response = completion(
-        model="huggingface/shahrukhx01/question-vs-statement-classifier",
+        model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
         messages=messages,
-        hf_task="text-classification",
     )
     print(f"response: {response}")
     assert isinstance(response, litellm.ModelResponse)
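
The docs and test changes above all lean on the new prefix check in `get_hf_task_for_model`. Below is a minimal standalone sketch of that routing logic, assuming `hf_task_list` contains (at least) the task names referenced in this diff; the real function also falls back to the TGI/conversational model lists when no prefix is present.

```python
# Standalone sketch of the prefix-based task routing added to
# get_hf_task_for_model() in litellm/llms/huggingface_restapi.py.
# HF_TASKS is an assumption standing in for hf_task_list.
from typing import Optional

HF_TASKS = {
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
}


def resolve_hf_task(model: str) -> Optional[str]:
    """Return the task encoded as the first path segment of the model name."""
    # litellm strips the "huggingface/" provider prefix before this point, so
    # "huggingface/text-classification/shahrukhx01/question-vs-statement-classifier"
    # arrives here as "text-classification/shahrukhx01/question-vs-statement-classifier".
    prefix = model.split("/")[0]
    if prefix in HF_TASKS:
        return prefix
    return None  # no task prefix; caller falls back to its default detection


print(resolve_hf_task("text-classification/shahrukhx01/question-vs-statement-classifier"))
# text-classification
print(resolve_hf_task("WizardLM/WizardCoder-Python-34B-V1.0"))
# None
```

Because the task now rides along in the model name, the `hf_task` optional param is no longer consulted, which is why the second hunk in `huggingface_restapi.py` drops the `optional_params.get("hf_task")` branch and the updated test passes the task via the model string instead.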