forked from phoenix/litellm-mirror
fix(huggingface_restapi.py): fix task extraction from model name
parent 900bb9aba8
commit 8117af664c
3 changed files with 18 additions and 10 deletions
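
Summary: the task-routing helper in huggingface_restapi.py previously returned only the inferred task, leaving any explicit task prefix (e.g. "conversational/") on the model name. It now returns a (task, model) tuple so the call site can rebind model to the bare repo id.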
litellm/llms/huggingface_restapi.py

@@ -6,7 +6,7 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any, Literal
+from typing import Callable, Dict, List, Any, Literal, Tuple
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -227,20 +227,21 @@ def read_tgi_conv_models():
     return set(), set()


-def get_hf_task_for_model(model: str) -> hf_tasks:
+def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
     if model.split("/")[0] in hf_task_list:
-        return model.split("/")[0]  # type: ignore
+        split_model = model.split("/", 1)
+        return split_model[0], split_model[1]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
-        return "text-generation-inference"
+        return "text-generation-inference", model
     elif model in conversational_models:
-        return "conversational"
+        return "conversational", model
     elif "roneneldan/TinyStories" in model:
-        return "text-generation"
+        return "text-generation", model
     else:
-        return "text-generation-inference"  # default to tgi
+        return "text-generation-inference", model  # default to tgi


 class Huggingface(BaseLLM):
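A quick sketch of the new behavior. This is a standalone reconstruction for illustration only: the real helper also consults the tgi/conversational model lists read from disk, and hf_task_list / hf_tasks are defined elsewhere in the module (the values below are assumed from that Literal).

from typing import Tuple

# hf_task_list reconstructed here for the sketch; values assumed from the module.
hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

def get_hf_task_for_model(model: str) -> Tuple[str, str]:
    # If the model name carries an explicit task prefix, split it off
    # so only the bare repo id is returned as the model.
    if model.split("/")[0] in hf_task_list:
        split_model = model.split("/", 1)
        return split_model[0], split_model[1]
    return "text-generation-inference", model  # default to tgi

print(get_hf_task_for_model("conversational/facebook/blenderbot-400M-distill"))
# ("conversational", "facebook/blenderbot-400M-distill")
print(get_hf_task_for_model("mistralai/Mistral-7B-Instruct-v0.1"))
# ("text-generation-inference", "mistralai/Mistral-7B-Instruct-v0.1")

Before this commit, only the task was returned and the model name kept its task prefix, which is what the call-site change in the next hunk corrects.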
@@ -403,7 +404,7 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            task = get_hf_task_for_model(model)
+            task, model = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
                 raise Exception(
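At the call site the helper's second element rebinds model, so everything downstream sees the bare repo id. Reusing the sketch above, with a hypothetical prefixed model name:

task, model = get_hf_task_for_model(
    "text-classification/distilbert-base-uncased-finetuned-sst-2-english"
)
# task  == "text-classification"
# model == "distilbert-base-uncased-finetuned-sst-2-english"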
@@ -514,7 +515,7 @@ class Huggingface(BaseLLM):
             if task == "text-generation-inference":
                 data["parameters"] = inference_params
                 data["stream"] = (  # type: ignore
-                    True
+                    True  # type: ignore
                     if "stream" in optional_params
                     and optional_params["stream"] == True
                     else False
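The stream flag itself is just a strict boolean coercion of the caller's option; a minimal standalone sketch of that expression (optional_params and data stand in for the options and request dicts from the surrounding method):

optional_params = {"stream": True}  # request options, as passed by the caller
data = {}
data["stream"] = (
    True
    if "stream" in optional_params and optional_params["stream"] == True
    else False
)
assert data["stream"] is True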