fix(huggingface_restapi.py): fix task extraction from model name

commit 8117af664c
parent 900bb9aba8
Author: Krrish Dholakia
Date:   2024-05-15 07:28:19 -07:00

3 changed files with 18 additions and 10 deletions

@@ -6,7 +6,7 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any, Literal
+from typing import Callable, Dict, List, Any, Literal, Tuple
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -227,20 +227,21 @@ def read_tgi_conv_models():
         return set(), set()


-def get_hf_task_for_model(model: str) -> hf_tasks:
+def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
     if model.split("/")[0] in hf_task_list:
-        return model.split("/")[0]  # type: ignore
+        split_model = model.split("/", 1)
+        return split_model[0], split_model[1]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
-        return "text-generation-inference"
+        return "text-generation-inference", model
     elif model in conversational_models:
-        return "conversational"
+        return "conversational", model
     elif "roneneldan/TinyStories" in model:
-        return "text-generation"
+        return "text-generation", model
     else:
-        return "text-generation-inference"  # default to tgi
+        return "text-generation-inference", model  # default to tgi


 class Huggingface(BaseLLM):
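
For clarity, here is a minimal, standalone sketch of the task-extraction behavior this hunk introduces. The function name extract_task and the contents of hf_task_list below are illustrative stand-ins for the module-level definitions in huggingface_restapi.py, and only the default fallback branch of the real helper is reproduced:

from typing import Tuple

# Illustrative stand-in for the module-level hf_task_list in huggingface_restapi.py.
hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]


def extract_task(model: str) -> Tuple[str, str]:
    # If the model name is prefixed with a task, split only on the first "/" so the
    # remainder keeps its own "org/model" structure, then return (task, model).
    if model.split("/")[0] in hf_task_list:
        task, remainder = model.split("/", 1)
        return task, remainder
    # Default branch, mirroring the diff's fallback to TGI.
    return "text-generation-inference", model


print(extract_task("text-classification/shahrukhx01/question-vs-statement-classifier"))
# ('text-classification', 'shahrukhx01/question-vs-statement-classifier')
print(extract_task("meta-llama/Llama-2-7b-chat-hf"))
# ('text-generation-inference', 'meta-llama/Llama-2-7b-chat-hf')
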
@@ -403,7 +404,7 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            task = get_hf_task_for_model(model)
+            task, model = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
                 raise Exception(
@@ -514,7 +515,7 @@ class Huggingface(BaseLLM):
             if task == "text-generation-inference":
                 data["parameters"] = inference_params
                 data["stream"] = (  # type: ignore
-                    True
+                    True  # type: ignore
                     if "stream" in optional_params
                     and optional_params["stream"] == True
                     else False

@@ -20,7 +20,10 @@ model_list:
       api_base: os.environ/AZURE_API_BASE
       input_cost_per_token: 0.0
       output_cost_per_token: 0.0
+  - model_name: bert-classifier
+    litellm_params:
+      model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
+      api_key: os.environ/HUGGINGFACE_API_KEY
 router_settings:
   redis_host: redis
   # redis_password: <your redis password>
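
The new bert-classifier entry embeds the Hugging Face task ("text-classification") directly in the model name, which the updated helper above now strips out before the request is sent. A hedged usage sketch, not part of this commit, assuming the standard litellm.completion entry point and a Hugging Face token in HUGGINGFACE_API_KEY:

import litellm

# Sketch only: the task-prefixed model name from the config above, called directly.
# Assumes HUGGINGFACE_API_KEY is set; the response shape depends on the task pipeline.
response = litellm.completion(
    model="huggingface/text-classification/shahrukhx01/question-vs-statement-classifier",
    messages=[{"role": "user", "content": "Is this a question?"}],
)
print(response)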

@@ -1318,6 +1318,10 @@ def test_hf_test_completion_tgi():
     def mock_post(url, data=None, json=None, headers=None):
+        print(f"url={url}")
+        if "text-classification" in url:
+            raise Exception("Model not found")
+
         mock_response = MagicMock()
         mock_response.status_code = 200
         mock_response.headers = {"Content-Type": "application/json"}