diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index ad3c570e7..c54dba75f 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -6,7 +6,7 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any, Literal
+from typing import Callable, Dict, List, Any, Literal, Tuple
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -227,20 +227,21 @@ def read_tgi_conv_models():
     return set(), set()


-def get_hf_task_for_model(model: str) -> hf_tasks:
+def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
     if model.split("/")[0] in hf_task_list:
-        return model.split("/")[0]  # type: ignore
+        split_model = model.split("/", 1)
+        return split_model[0], split_model[1]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
-        return "text-generation-inference"
+        return "text-generation-inference", model
     elif model in conversational_models:
-        return "conversational"
+        return "conversational", model
     elif "roneneldan/TinyStories" in model:
-        return "text-generation"
+        return "text-generation", model
     else:
-        return "text-generation-inference"  # default to tgi
+        return "text-generation-inference", model  # default to tgi


 class Huggingface(BaseLLM):
@@ -403,7 +404,7 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            task = get_hf_task_for_model(model)
+            task, model = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
                 raise Exception(
@@ -514,7 +515,7 @@ class Huggingface(BaseLLM):
             if task == "text-generation-inference":
                 data["parameters"] = inference_params
                 data["stream"] = (  # type: ignore
-                    True
+                    True  # type: ignore
                     if "stream" in optional_params and optional_params["stream"] == True
                     else False
diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 5f5ee89c3..42f9e3be5 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -20,7 +20,10 @@ model_list:
     api_base: os.environ/AZURE_API_BASE
     input_cost_per_token: 0.0
     output_cost_per_token: 0.0
-
+- model_name: bert-classifier
+  litellm_params:
+    model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
+    api_key: os.environ/HUGGINGFACE_API_KEY
 router_settings:
   redis_host: redis
   # redis_password:
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index b5fa141cc..4441ddf29 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1318,6 +1318,10 @@ def test_hf_test_completion_tgi():

         def mock_post(url, data=None, json=None, headers=None):
+
+            print(f"url={url}")
+            if "text-classification" in url:
+                raise Exception("Model not found")
             mock_response = MagicMock()
             mock_response.status_code = 200
             mock_response.headers = {"Content-Type": "application/json"}
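For clarity, here is a minimal sketch of the new calling convention this patch introduces: get_hf_task_for_model now returns a (task, model) tuple, and a recognized task prefix is split off the model string with maxsplit=1 so the rest of the repo path stays intact. This assumes the patch is applied and that "text-classification" is a member of hf_task_list, which this diff does not show.

from litellm.llms.huggingface_restapi import get_hf_task_for_model

# A recognized task prefix is stripped; maxsplit=1 keeps the remaining
# repo path ("shahrukhx01/question-vs-statement-classifier") in one piece.
task, model = get_hf_task_for_model(
    "text-classification/shahrukhx01/question-vs-statement-classifier"
)
assert task == "text-classification"
assert model == "shahrukhx01/question-vs-statement-classifier"

# With no recognized prefix, every branch now returns the model string
# unchanged alongside the task (registry lookup or the TGI default).
task, model = get_hf_task_for_model("mistralai/Mistral-7B-Instruct-v0.1")
assert model == "mistralai/Mistral-7B-Instruct-v0.1"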
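The bert-classifier entry added to the proxy config can then be exercised end to end. A hedged sketch follows; the model string and environment variable come from the YAML above, while the prompt and the printed field are illustrative only:

import os
import litellm

# The "huggingface/<task>/<repo>/<model>" form routes the request to the
# Hugging Face text-classification pipeline instead of the
# text-generation-inference default.
response = litellm.completion(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    messages=[{"role": "user", "content": "Is this a question?"}],
    api_key=os.environ["HUGGINGFACE_API_KEY"],
)
print(response.choices[0].message.content)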