diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index a2c4457c2..26591b95d 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -6,10 +6,12 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any
+from typing import Callable, Dict, List, Any, Literal
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.types.completion import ChatCompletionMessageToolCallParam
+import enum
 
 
 class HuggingfaceError(Exception):
@@ -39,11 +41,29 @@
         )  # Call the base class constructor with the parameters it needs
 
 
+hf_task_list = [
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+hf_tasks = Literal[
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+
 class HuggingfaceConfig:
     """
     Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
     """
 
+    hf_task: Optional[hf_tasks] = (
+        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
+    )
     best_of: Optional[int] = None
     decoder_input_details: Optional[bool] = None
     details: Optional[bool] = True  # enables returning logprobs + best of
@@ -101,6 +121,51 @@ class HuggingfaceConfig:
             and v is not None
         }
 
+    def get_supported_openai_params(self):
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "echo",
+        ]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+        return optional_params
+
 
 def output_parser(generated_text: str):
     """
@@ -162,7 +227,7 @@ def read_tgi_conv_models():
         return set(), set()
 
 
-def get_hf_task_for_model(model):
+def get_hf_task_for_model(model: str) -> hf_tasks:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
     tgi_models, conversational_models = read_tgi_conv_models()
@@ -171,7 +236,7 @@
     elif model in conversational_models:
         return "conversational"
    elif "roneneldan/TinyStories" in model:
-        return None
+        return "text-generation"
     else:
         return "text-generation-inference"  # default to tgi
 
@@ -202,7 +267,7 @@
         self,
         completion_response,
         model_response,
-        task,
+        task: hf_tasks,
         optional_params,
         encoding,
         input_text,
@@ -270,6 +335,10 @@
                         )
                         choices_list.append(choice_obj)
                     model_response["choices"].extend(choices_list)
+        elif task == "text-classification":
+            model_response["choices"][0]["message"]["content"] = json.dumps(
+                completion_response
+            )
         else:
             if len(completion_response[0]["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = output_parser(
@@ -332,7 +401,16 @@
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            task = get_hf_task_for_model(model)
+            if optional_params.get("hf_task") is None:
+                task = get_hf_task_for_model(model)
+            else:
+                task = optional_params.get("hf_task")  # type: ignore
+            ## VALIDATE API FORMAT
+            if task is None or not isinstance(task, str) or task not in hf_task_list:
+                raise Exception(
+                    "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
+                )
+
             print_verbose(f"{model}, {task}")
             completion_url = ""
             input_text = ""
@@ -433,14 +511,15 @@
                 inference_params.pop("return_full_text")
                 data = {
                     "inputs": prompt,
-                    "parameters": inference_params,
-                    "stream": (  # type: ignore
+                }
+                if task == "text-generation-inference":
+                    data["parameters"] = inference_params
+                    data["stream"] = (  # type: ignore
                         True
                         if "stream" in optional_params
                         and optional_params["stream"] == True
                         else False
-                    ),
-                }
+                    )
                 input_text = prompt
                 ## LOGGING
                 logging_obj.pre_call(
@@ -531,10 +610,10 @@
                 isinstance(completion_response, dict)
                 and "error" in completion_response
             ):
-                print_verbose(f"completion error: {completion_response['error']}")
+                print_verbose(f"completion error: {completion_response['error']}")  # type: ignore
                 print_verbose(f"response.status_code: {response.status_code}")
                 raise HuggingfaceError(
-                    message=completion_response["error"],
+                    message=completion_response["error"],  # type: ignore
                     status_code=response.status_code,
                 )
             return self.convert_to_model_response_object(
@@ -563,7 +642,7 @@
         data: dict,
         headers: dict,
         model_response: ModelResponse,
-        task: str,
+        task: hf_tasks,
         encoding: Any,
         input_text: str,
         model: str,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index 630baf346..4da489cc5 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -13,6 +13,7 @@ import litellm
 from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+from unittest.mock import patch, MagicMock
 
 # litellm.num_retries=3
 litellm.cache = None
@@ -1145,15 +1146,92 @@ def test_get_hf_task_for_model():
 # ################### Hugging Face TGI models ########################
 # # TGI model
 # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
-def hf_test_completion_tgi():
-    # litellm.set_verbose=True
+def tgi_mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        {
+            "generated_text": "<|assistant|>\nI'm",
+            "details": {
+                "finish_reason": "length",
+                "generated_tokens": 10,
+                "seed": None,
+                "prefill": [],
+                "tokens": [
+                    {
+                        "id": 28789,
+                        "text": "<",
+                        "logprob": -0.025222778,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.000003695488,
+                        "special": False,
+                    },
+                    {
+                        "id": 489,
+                        "text": "ass",
+                        "logprob": -0.0000019073486,
+                        "special": False,
+                    },
+                    {
+                        "id": 11143,
+                        "text": "istant",
+                        "logprob": -0.000002026558,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.0000015497208,
+                        "special": False,
+                    },
+                    {
+                        "id": 28767,
+                        "text": ">",
+                        "logprob": -0.0000011920929,
+                        "special": False,
+                    },
+                    {
+                        "id": 13,
+                        "text": "\n",
+                        "logprob": -0.00009703636,
+                        "special": False,
+                    },
+                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
+                    {
+                        "id": 28742,
+                        "text": "'",
+                        "logprob": -0.88183594,
+                        "special": False,
+                    },
+                    {
+                        "id": 28719,
+                        "text": "m",
+                        "logprob": -0.00032639503,
+                        "special": False,
+                    },
+                ],
+            },
+        }
+    ]
+    return mock_response
+
+
+def test_hf_test_completion_tgi():
+    litellm.set_verbose = True
     try:
-        response = completion(
-            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
-            messages=[{"content": "Hello, how are you?", "role": "user"}],
-        )
-        # Add any assertions here to check the response
-        print(response)
+        with patch("requests.post", side_effect=tgi_mock_post):
+            response = completion(
+                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+                messages=[{"content": "Hello, how are you?", "role": "user"}],
+                max_tokens=10,
+            )
+            # Add any assertions here to check the response
+            print(response)
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:
@@ -1191,6 +1269,41 @@ def hf_test_completion_tgi():
 # except Exception as e:
 #     pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_none_task()
+
+
+def mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        [
+            {"label": "LABEL_0", "score": 0.9990691542625427},
+            {"label": "LABEL_1", "score": 0.0009308889275416732},
+        ]
+    ]
+    return mock_response
+
+
+def test_hf_classifier_task():
+    try:
+        with patch("requests.post", side_effect=mock_post):
+            litellm.set_verbose = True
+            user_message = "I like you. I love you"
+            messages = [{"content": user_message, "role": "user"}]
+            response = completion(
+                model="huggingface/shahrukhx01/question-vs-statement-classifier",
+                messages=messages,
+                hf_task="text-classification",
+            )
+            print(f"response: {response}")
+            assert isinstance(response, litellm.ModelResponse)
+            assert isinstance(response.choices[0], litellm.Choices)
+            assert response.choices[0].message.content is not None
+            assert isinstance(response.choices[0].message.content, str)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {str(e)}")
+
+
 ########################### End of Hugging Face Tests ##############################################
 # def test_completion_hf_api():
 #     # failing on circle ci commenting out
diff --git a/litellm/utils.py b/litellm/utils.py
index 838d0fe55..f20cd220c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4840,6 +4840,7 @@ def get_optional_params_embeddings(
 def get_optional_params(
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
+    model: str,
     functions=None,
     function_call=None,
     temperature=None,
@@ -4853,7 +4854,6 @@
     frequency_penalty=None,
     logit_bias=None,
     user=None,
-    model=None,
     custom_llm_provider="",
     response_format=None,
     seed=None,
@@ -4882,7 +4882,7 @@
         passed_params[k] = v
 
-    optional_params = {}
+    optional_params: Dict = {}
 
     common_auth_dict = litellm.common_cloud_provider_auth_params
     if custom_llm_provider in common_auth_dict["providers"]:
@@ -5156,41 +5156,9 @@
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        if n is not None:
-            optional_params["best_of"] = n
-        if presence_penalty is not None:
-            optional_params["repetition_penalty"] = presence_penalty
-        if "echo" in passed_params:
-            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-            optional_params["decoder_input_details"] = special_params["echo"]
-            passed_params.pop(
-                "echo", None
-            )  # since we handle translating echo, we should not send it to TGI request
+        optional_params = litellm.HuggingfaceConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "together_ai":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
@@ -6150,7 +6118,7 @@
             "seed",
         ]
     elif custom_llm_provider == "huggingface":
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
         return [
            "stream",
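
Example usage of the new `hf_task` parameter (a minimal sketch mirroring test_hf_classifier_task above; it assumes Hugging Face credentials are configured in the environment as usual):

import litellm

# Passing hf_task routes the request through the Hugging Face "text-classification"
# API spec instead of the task otherwise inferred by get_hf_task_for_model().
response = litellm.completion(
    model="huggingface/shahrukhx01/question-vs-statement-classifier",
    messages=[{"content": "I like you. I love you", "role": "user"}],
    hf_task="text-classification",
)

# For text-classification, convert_to_model_response_object() JSON-encodes the raw
# label/score pairs into the first choice's message content.
print(response.choices[0].message.content)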