fix(huggingface_restapi.py): fix task extraction from model name

Krrish Dholakia 2024-05-15 07:28:19 -07:00
parent 900bb9aba8
commit 8117af664c
3 changed files with 18 additions and 10 deletions

huggingface_restapi.py

@@ -6,7 +6,7 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any, Literal
+from typing import Callable, Dict, List, Any, Literal, Tuple
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -227,20 +227,21 @@ def read_tgi_conv_models():
         return set(), set()
 
 
-def get_hf_task_for_model(model: str) -> hf_tasks:
+def get_hf_task_for_model(model: str) -> Tuple[hf_tasks, str]:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
     if model.split("/")[0] in hf_task_list:
-        return model.split("/")[0]  # type: ignore
+        split_model = model.split("/", 1)
+        return split_model[0], split_model[1]  # type: ignore
     tgi_models, conversational_models = read_tgi_conv_models()
     if model in tgi_models:
-        return "text-generation-inference"
+        return "text-generation-inference", model
     elif model in conversational_models:
-        return "conversational"
+        return "conversational", model
     elif "roneneldan/TinyStories" in model:
-        return "text-generation"
+        return "text-generation", model
     else:
-        return "text-generation-inference"  # default to tgi
+        return "text-generation-inference", model  # default to tgi
 
 
 class Huggingface(BaseLLM):
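
For illustration, a minimal sketch of what the updated helper now returns. It mirrors the logic added above but is not a verbatim copy of the litellm function; the task list is a hand-written subset and the model names are hypothetical.

    # Minimal sketch of the updated task extraction; not the litellm code itself.
    # The task list below is a trimmed, assumed subset; the model names are hypothetical.
    from typing import Tuple

    hf_task_list = [
        "text-generation-inference",
        "conversational",
        "text-classification",
        "text-generation",
    ]

    def get_task_and_model(model: str) -> Tuple[str, str]:
        # "task/org/model" -> ("task", "org/model"); maxsplit=1 keeps the org prefix intact
        if model.split("/")[0] in hf_task_list:
            task, remaining = model.split("/", 1)
            return task, remaining
        # no recognized task prefix: default to TGI and leave the model name untouched
        return "text-generation-inference", model

    print(get_task_and_model("conversational/facebook/blenderbot-400M-distill"))
    # ('conversational', 'facebook/blenderbot-400M-distill')
    print(get_task_and_model("mistralai/Mistral-7B-Instruct-v0.2"))
    # ('text-generation-inference', 'mistralai/Mistral-7B-Instruct-v0.2')

The key difference from the previous version is the second return value: callers that used to receive only the task now also get the model name with any task prefix stripped, as the next hunk shows.
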
@@ -403,7 +404,7 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            task = get_hf_task_for_model(model)
+            task, model = get_hf_task_for_model(model)
             ## VALIDATE API FORMAT
             if task is None or not isinstance(task, str) or task not in hf_task_list:
                 raise Exception(
@@ -514,7 +515,7 @@
                 if task == "text-generation-inference":
                     data["parameters"] = inference_params
                     data["stream"] = (  # type: ignore
-                        True
+                        True  # type: ignore
                         if "stream" in optional_params
                         and optional_params["stream"] == True
                         else False
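
For context, a hedged end-to-end example of the call path this fix sits on. The model name and message are illustrative, and running it needs a Hugging Face API token plus network access; litellm strips the "huggingface/" provider prefix before the Hugging Face handler runs, so the string reaching get_hf_task_for_model here would be "conversational/facebook/blenderbot-400M-distill".

    # Illustrative usage only; the model id and message are hypothetical, and an
    # API key (environment variable or api_key argument) is required to actually run it.
    import litellm

    response = litellm.completion(
        model="huggingface/conversational/facebook/blenderbot-400M-distill",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
    )
    print(response.choices[0].message.content)
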