forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): fix task extraction from model name
This commit is contained in:
parent
900bb9aba8
commit
8117af664c
3 changed files with 18 additions and 10 deletions
|
@@ -6,7 +6,7 @@ import httpx, requests
|
||||||
from .base import BaseLLM
|
from .base import BaseLLM
|
||||||
import time
|
import time
|
||||||
import litellm
|
import litellm
|
||||||
from typing import Callable, Dict, List, Any, Literal
|
from typing import Callable, Dict, List, Any, Literal, Tuple
|
||||||
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from .prompt_templates.factory import prompt_factory, custom_prompt
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
||||||
|
@@ -227,20 +227,21 @@ def read_tgi_conv_models():
|
||||||
return set(), set()
|
return set(), set()
|
||||||
|
|
||||||
|
|
||||||
def get_hf_task_for_model(model: str) -> hf_tasks:
|
def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
|
||||||
# read text file, cast it to set
|
# read text file, cast it to set
|
||||||
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
|
||||||
if model.split("/")[0] in hf_task_list:
|
if model.split("/")[0] in hf_task_list:
|
||||||
return model.split("/")[0] # type: ignore
|
split_model = model.split("/", 1)
|
||||||
|
return split_model[0], split_model[1] # type: ignore
|
||||||
tgi_models, conversational_models = read_tgi_conv_models()
|
tgi_models, conversational_models = read_tgi_conv_models()
|
||||||
if model in tgi_models:
|
if model in tgi_models:
|
||||||
return "text-generation-inference"
|
return "text-generation-inference", model
|
||||||
elif model in conversational_models:
|
elif model in conversational_models:
|
||||||
return "conversational"
|
return "conversational", model
|
||||||
elif "roneneldan/TinyStories" in model:
|
elif "roneneldan/TinyStories" in model:
|
||||||
return "text-generation"
|
return "text-generation", model
|
||||||
else:
|
else:
|
||||||
return "text-generation-inference" # default to tgi
|
return "text-generation-inference", model # default to tgi
|
||||||
|
|
||||||
|
|
||||||
class Huggingface(BaseLLM):
|
class Huggingface(BaseLLM):
|
||||||
|
@@ -403,7 +404,7 @@ class Huggingface(BaseLLM):
|
||||||
exception_mapping_worked = False
|
exception_mapping_worked = False
|
||||||
try:
|
try:
|
||||||
headers = self.validate_environment(api_key, headers)
|
headers = self.validate_environment(api_key, headers)
|
||||||
task = get_hf_task_for_model(model)
|
task, model = get_hf_task_for_model(model)
|
||||||
## VALIDATE API FORMAT
|
## VALIDATE API FORMAT
|
||||||
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
if task is None or not isinstance(task, str) or task not in hf_task_list:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
|
@@ -514,7 +515,7 @@ class Huggingface(BaseLLM):
|
||||||
if task == "text-generation-inference":
|
if task == "text-generation-inference":
|
||||||
data["parameters"] = inference_params
|
data["parameters"] = inference_params
|
||||||
data["stream"] = ( # type: ignore
|
data["stream"] = ( # type: ignore
|
||||||
True
|
True # type: ignore
|
||||||
if "stream" in optional_params
|
if "stream" in optional_params
|
||||||
and optional_params["stream"] == True
|
and optional_params["stream"] == True
|
||||||
else False
|
else False
|
||||||
|
|
|
@@ -20,7 +20,10 @@ model_list:
|
||||||
api_base: os.environ/AZURE_API_BASE
|
api_base: os.environ/AZURE_API_BASE
|
||||||
input_cost_per_token: 0.0
|
input_cost_per_token: 0.0
|
||||||
output_cost_per_token: 0.0
|
output_cost_per_token: 0.0
|
||||||
|
- model_name: bert-classifier
|
||||||
|
litellm_params:
|
||||||
|
model: huggingface/text-classification/shahrukhx01/question-vs-statement-classifier
|
||||||
|
api_key: os.environ/HUGGINGFACE_API_KEY
|
||||||
router_settings:
|
router_settings:
|
||||||
redis_host: redis
|
redis_host: redis
|
||||||
# redis_password: <your redis password>
|
# redis_password: <your redis password>
|
||||||
|
|
|
@@ -1318,6 +1318,10 @@ def test_hf_test_completion_tgi():
|
||||||
|
|
||||||
|
|
||||||
def mock_post(url, data=None, json=None, headers=None):
|
def mock_post(url, data=None, json=None, headers=None):
|
||||||
|
|
||||||
|
print(f"url={url}")
|
||||||
|
if "text-classification" in url:
|
||||||
|
raise Exception("Model not found")
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.status_code = 200
|
mock_response.status_code = 200
|
||||||
mock_response.headers = {"Content-Type": "application/json"}
|
mock_response.headers = {"Content-Type": "application/json"}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue