diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 796cf6b00c..7783068afc 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -25,6 +25,52 @@ def validate_environment(api_key):
         headers["Authorization"] = f"Bearer {api_key}"
     return headers
 
+tgi_models_cache = None
+conv_models_cache = None
+def read_tgi_conv_models():
+    try:
+        global tgi_models_cache, conv_models_cache
+        # Check if the cache is already populated
+        # so we don't keep on reading txt file if there are 1k requests
+        if (tgi_models_cache is not None) and (conv_models_cache is not None):
+            return tgi_models_cache, conv_models_cache
+        # If not, read the file and populate the cache
+        tgi_models = set()
+        script_directory = os.path.dirname(os.path.abspath(__file__))
+        # Construct the file path relative to the script's directory
+        file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_text_generation_models.txt")
+
+        with open(file_path, 'r') as file:
+            for line in file:
+                tgi_models.add(line.strip())
+
+        # Cache the set for future use
+        tgi_models_cache = tgi_models
+
+        # If not, read the file and populate the cache
+        file_path = os.path.join(script_directory, "huggingface_llms_metadata", "hf_conversational_models.txt")
+        conv_models = set()
+        with open(file_path, 'r') as file:
+            for line in file:
+                conv_models.add(line.strip())
+        # Cache the set for future use
+        conv_models_cache = conv_models
+        return tgi_models, conv_models
+    except:
+        return set(), set()
+
+
+def get_hf_task_for_model(model):
+    # read text file, cast it to set
+    # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
+    tgi_models, conversational_models = read_tgi_conv_models()
+    if model in tgi_models:
+        return "text-generation-inference"
+    elif model in conversational_models:
+        return "conversational"
+    else:
+        return None
+
 def completion(
     model: str,
     messages: list,
@@ -40,7 +86,8 @@ def completion(
     logger_fn=None,
 ):
     headers = validate_environment(api_key)
-    task = optional_params.pop("task")
+    task = get_hf_task_for_model(model)
+    print_verbose(f"{model}, {task}")
     completion_url = ""
     input_text = None
     if "https" in model:
@@ -59,6 +106,7 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
         inference_params.pop("return_full_text")
+        inference_params.pop("task")
         past_user_inputs = []
         generated_responses = []
         text = ""
@@ -79,6 +127,7 @@ def completion(
         }
         input_text = "".join(message["content"] for message in messages)
     elif task == "text-generation-inference":
+        # always send "details" and "return_full_text" as params
         if model in custom_prompt_dict:
             # check if the model has a registered custom prompt
             model_prompt_details = custom_prompt_dict[model]
@@ -92,7 +141,6 @@ def completion(
             prompt = prompt_factory(model=model, messages=messages)
         if "https://api-inference.huggingface.co/models" in completion_url:
             inference_params = copy.deepcopy(optional_params)
-            inference_params.pop("details")
             data = {
                 "inputs": prompt,
                 "parameters": inference_params,
@@ -105,7 +153,9 @@ def completion(
                 "stream": True if "stream" in optional_params and optional_params["stream"] == True else False,
             }
         input_text = prompt
-    elif task == "other" or task == None:
+    else:
+        # Non TGI and Conversational llms
+        # We need this branch, it removes 'details' and 'return_full_text' from params
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
@@ -120,6 +170,7 @@ def completion(
         inference_params = copy.deepcopy(optional_params)
         inference_params.pop("details")
         inference_params.pop("return_full_text")
+        inference_params.pop("task")
         data = {
             "inputs": prompt,
             "parameters": inference_params,
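
A minimal usage sketch of the new task lookup, separate from the diff itself. It assumes the huggingface_llms_metadata/*.txt files ship alongside huggingface_restapi.py; the model names below are illustrative placeholders, not entries confirmed to be in those files:

    from litellm.llms.huggingface_restapi import get_hf_task_for_model

    # Resolves the task from the cached hf_text_generation_models.txt /
    # hf_conversational_models.txt sets; returns None for unlisted models,
    # which sends the request down the generic "else" branch in completion().
    task = get_hf_task_for_model("HuggingFaceH4/zephyr-7b-beta")      # "text-generation-inference" if listed
    task = get_hf_task_for_model("facebook/blenderbot-400M-distill")  # "conversational" if listed
    task = get_hf_task_for_model("some-org/unlisted-model")           # None otherwise

Because the file reads are cached in module-level sets, repeated completion() calls only touch the metadata files once per process.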