fix(huggingface_restapi.py): fix task extraction from model name

This commit is contained in:
Krrish Dholakia 2024-05-15 07:28:19 -07:00
parent 900bb9aba8
commit 8117af664c
3 changed files with 18 additions and 10 deletions

View file

@@ -6,7 +6,7 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any, Literal
from typing import Callable, Dict, List, Any, Literal, Tuple
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -227,20 +227,21 @@ def read_tgi_conv_models():
return set(), set()
def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
    """Infer the Hugging Face task for ``model`` and return ``(task, model_name)``.

    If the model string carries an explicit task prefix (e.g.
    ``"text-classification/org/model"``), the prefix is stripped off and
    returned as the task alongside the remaining model name. Otherwise the
    model is classified by membership in the TGI / conversational model lists
    read from the bundled metadata files, defaulting to
    ``"text-generation-inference"``.
    """
    # Explicit task prefix: "task/<org>/<model>" -> ("task", "<org>/<model>").
    # maxsplit=1 keeps the org/model remainder intact.
    # NOTE(review): assumes a "/" always follows a recognized task prefix; a
    # bare task string with no model part would raise IndexError — confirm
    # callers never pass that.
    if model.split("/")[0] in hf_task_list:
        split_model = model.split("/", 1)
        return split_model[0], split_model[1]  # type: ignore

    # No explicit prefix: look the model up in the metadata-derived sets.
    tgi_models, conversational_models = read_tgi_conv_models()
    if model in tgi_models:
        return "text-generation-inference", model
    elif model in conversational_models:
        return "conversational", model
    elif "roneneldan/TinyStories" in model:
        return "text-generation", model
    else:
        return "text-generation-inference", model  # default to tgi
class Huggingface(BaseLLM):
@@ -403,7 +404,7 @@ class Huggingface(BaseLLM):
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key, headers)
task = get_hf_task_for_model(model)
task, model = get_hf_task_for_model(model)
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
@@ -514,7 +515,7 @@
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True
True # type: ignore
if "stream" in optional_params
and optional_params["stream"] == True
else False

View file

@@ -20,7 +20,10 @@ model_list:
api_base: os.environ/AZURE_API_BASE
input_cost_per_token: 0.0
output_cost_per_token: 0.0
- model_name: bert-classifier
litellm_params:
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
api_key: os.environ/HUGGINGFACE_API_KEY
router_settings:
redis_host: redis
# redis_password: <your redis password>

View file

@@ -1318,6 +1318,10 @@ def test_hf_test_completion_tgi():
def mock_post(url, data=None, json=None, headers=None):
print(f"url={url}")
if "text-classification" in url:
raise Exception("Model not found")
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}