test(test_completion.py): reintegrate testing for huggingface tgi + non-tgi

Krrish Dholakia 2024-05-10 14:07:01 -07:00
parent 781d5888c3
commit c17f221b89
3 changed files with 218 additions and 58 deletions


@@ -6,10 +6,12 @@ import httpx, requests
 from .base import BaseLLM
 import time
 import litellm
-from typing import Callable, Dict, List, Any
+from typing import Callable, Dict, List, Any, Literal
 from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
 from typing import Optional
 from .prompt_templates.factory import prompt_factory, custom_prompt
+from litellm.types.completion import ChatCompletionMessageToolCallParam
+import enum


 class HuggingfaceError(Exception):
@@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
         )  # Call the base class constructor with the parameters it needs


+hf_task_list = [
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+hf_tasks = Literal[
+    "text-generation-inference",
+    "conversational",
+    "text-classification",
+    "text-generation",
+]
+
+
 class HuggingfaceConfig:
     """
     Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
     """

+    hf_task: Optional[hf_tasks] = (
+        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
+    )
     best_of: Optional[int] = None
     decoder_input_details: Optional[bool] = None
     details: Optional[bool] = True  # enables returning logprobs + best of
@@ -101,6 +121,51 @@ class HuggingfaceConfig:
             and v is not None
         }

+    def get_supported_openai_params(self):
+        return [
+            "stream",
+            "temperature",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "echo",
+        ]
+
+    def map_openai_params(
+        self, non_default_params: dict, optional_params: dict
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # hugging face exception raised when temp==0
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # Need to sample if you want best of for hf inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+                optional_params["decoder_input_details"] = True
+        return optional_params
+

 def output_parser(generated_text: str):
     """
@@ -162,7 +227,7 @@ def read_tgi_conv_models():
         return set(), set()


-def get_hf_task_for_model(model):
+def get_hf_task_for_model(model: str) -> hf_tasks:
     # read text file, cast it to set
     # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
     tgi_models, conversational_models = read_tgi_conv_models()
@@ -171,7 +236,7 @@ def get_hf_task_for_model(model):
     elif model in conversational_models:
         return "conversational"
     elif "roneneldan/TinyStories" in model:
-        return None
+        return "text-generation"
     else:
         return "text-generation-inference"  # default to tgi

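Side note (not part of the diff): a quick sketch of the new task-resolution behavior. The import path below is an assumption, since this view does not show the file name, and the second check assumes the model id is absent from the TGI/conversational metadata lists.

# Import path assumed; adjust to wherever get_hf_task_for_model lives in litellm.
from litellm.llms.huggingface_restapi import get_hf_task_for_model

# TinyStories checkpoints now resolve to the plain "text-generation" pipeline instead of None.
assert get_hf_task_for_model("roneneldan/TinyStories-1M") == "text-generation"
# Models not listed in the tgi/conversational metadata files still default to TGI.
assert get_hf_task_for_model("some-org/some-model") == "text-generation-inference"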
@@ -202,7 +267,7 @@ class Huggingface(BaseLLM):
         self,
         completion_response,
         model_response,
-        task,
+        task: hf_tasks,
         optional_params,
         encoding,
         input_text,
@@ -270,6 +335,10 @@ class Huggingface(BaseLLM):
                         )
                         choices_list.append(choice_obj)
                     model_response["choices"].extend(choices_list)
+        elif task == "text-classification":
+            model_response["choices"][0]["message"]["content"] = json.dumps(
+                completion_response
+            )
         else:
             if len(completion_response[0]["generated_text"]) > 0:
                 model_response["choices"][0]["message"]["content"] = output_parser(
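Side note (not part of the diff): for hf_task="text-classification" the raw classifier output is serialized straight into the assistant message. A standalone sketch, using the same label/score shape as the mocked classifier test added further down:

import json

# Shape returned by an HF text-classification endpoint (mirrors mock_post in the test file below).
completion_response = [
    [
        {"label": "LABEL_0", "score": 0.9990691542625427},
        {"label": "LABEL_1", "score": 0.0009308889275416732},
    ]
]
# choices[0].message.content ends up holding this list as a JSON string.
print(json.dumps(completion_response))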
@@ -332,7 +401,16 @@ class Huggingface(BaseLLM):
         exception_mapping_worked = False
         try:
             headers = self.validate_environment(api_key, headers)
-            task = get_hf_task_for_model(model)
+            if optional_params.get("hf_task") is None:
+                task = get_hf_task_for_model(model)
+            else:
+                task = optional_params.get("hf_task")  # type: ignore
+
+            ## VALIDATE API FORMAT
+            if task is None or not isinstance(task, str) or task not in hf_task_list:
+                raise Exception(
+                    "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
+                )
             print_verbose(f"{model}, {task}")
             completion_url = ""
             input_text = ""
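Side note (not part of the diff): what the override looks like from the caller side, mirroring test_hf_classifier_task added further down. Without the mocked requests.post used in that test, this issues a real Hugging Face Inference API call.

import litellm

response = litellm.completion(
    model="huggingface/shahrukhx01/question-vs-statement-classifier",
    messages=[{"content": "I like you. I love you", "role": "user"}],
    hf_task="text-classification",  # must be one of hf_task_list, otherwise the validation above raises
)
print(response.choices[0].message.content)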
@@ -433,14 +511,15 @@ class Huggingface(BaseLLM):
                 inference_params.pop("return_full_text")
                 data = {
                     "inputs": prompt,
-                    "parameters": inference_params,
-                    "stream": (  # type: ignore
+                }
+                if task == "text-generation-inference":
+                    data["parameters"] = inference_params
+                    data["stream"] = (  # type: ignore
                         True
                         if "stream" in optional_params
                         and optional_params["stream"] == True
                         else False
-                    ),
-                }
+                    )
                 input_text = prompt
                 ## LOGGING
                 logging_obj.pre_call(
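Side note (not part of the diff): the net effect on the request body, sketched with placeholder values (the prompt and parameters below are illustrative, not taken from the commit):

prompt = "<|user|>\nHello, how are you?</s>\n<|assistant|>\n"  # placeholder prompt
inference_params = {"details": True, "max_new_tokens": 10}  # placeholder subset of params

# task == "text-generation-inference": TGI parameters and the stream flag are attached.
tgi_payload = {"inputs": prompt, "parameters": inference_params, "stream": False}

# non-TGI path through this branch (e.g. plain "text-generation"): only the raw inputs are sent.
non_tgi_payload = {"inputs": prompt}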
@@ -531,10 +610,10 @@ class Huggingface(BaseLLM):
                    isinstance(completion_response, dict)
                    and "error" in completion_response
                ):
-                    print_verbose(f"completion error: {completion_response['error']}")
+                    print_verbose(f"completion error: {completion_response['error']}")  # type: ignore
                    print_verbose(f"response.status_code: {response.status_code}")
                    raise HuggingfaceError(
-                        message=completion_response["error"],
+                        message=completion_response["error"],  # type: ignore
                        status_code=response.status_code,
                    )
                return self.convert_to_model_response_object(
@@ -563,7 +642,7 @@ class Huggingface(BaseLLM):
         data: dict,
         headers: dict,
         model_response: ModelResponse,
-        task: str,
+        task: hf_tasks,
         encoding: Any,
         input_text: str,
         model: str,


@@ -13,6 +13,7 @@ import litellm
 from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
 from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+from unittest.mock import patch, MagicMock

 # litellm.num_retries=3
 litellm.cache = None
@@ -1145,15 +1146,92 @@ def test_get_hf_task_for_model():
 # ################### Hugging Face TGI models ########################
 # # TGI model
 # # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
-def hf_test_completion_tgi():
-    # litellm.set_verbose=True
+def tgi_mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        {
+            "generated_text": "<|assistant|>\nI'm",
+            "details": {
+                "finish_reason": "length",
+                "generated_tokens": 10,
+                "seed": None,
+                "prefill": [],
+                "tokens": [
+                    {
+                        "id": 28789,
+                        "text": "<",
+                        "logprob": -0.025222778,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.000003695488,
+                        "special": False,
+                    },
+                    {
+                        "id": 489,
+                        "text": "ass",
+                        "logprob": -0.0000019073486,
+                        "special": False,
+                    },
+                    {
+                        "id": 11143,
+                        "text": "istant",
+                        "logprob": -0.000002026558,
+                        "special": False,
+                    },
+                    {
+                        "id": 28766,
+                        "text": "|",
+                        "logprob": -0.0000015497208,
+                        "special": False,
+                    },
+                    {
+                        "id": 28767,
+                        "text": ">",
+                        "logprob": -0.0000011920929,
+                        "special": False,
+                    },
+                    {
+                        "id": 13,
+                        "text": "\n",
+                        "logprob": -0.00009703636,
+                        "special": False,
+                    },
+                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
+                    {
+                        "id": 28742,
+                        "text": "'",
+                        "logprob": -0.88183594,
+                        "special": False,
+                    },
+                    {
+                        "id": 28719,
+                        "text": "m",
+                        "logprob": -0.00032639503,
+                        "special": False,
+                    },
+                ],
+            },
+        }
+    ]
+    return mock_response
+
+
+def test_hf_test_completion_tgi():
+    litellm.set_verbose = True
     try:
-        response = completion(
-            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
-            messages=[{"content": "Hello, how are you?", "role": "user"}],
-        )
-        # Add any assertions here to check the response
-        print(response)
+        with patch("requests.post", side_effect=tgi_mock_post):
+            response = completion(
+                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
+                messages=[{"content": "Hello, how are you?", "role": "user"}],
+                max_tokens=10,
+            )
+            # Add any assertions here to check the response
+            print(response)
     except litellm.ServiceUnavailableError as e:
         pass
     except Exception as e:
@@ -1191,6 +1269,41 @@ def hf_test_completion_tgi():
 # except Exception as e:
 #     pytest.fail(f"Error occurred: {e}")
 # hf_test_completion_none_task()
+
+
+def mock_post(url, data=None, json=None, headers=None):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = [
+        [
+            {"label": "LABEL_0", "score": 0.9990691542625427},
+            {"label": "LABEL_1", "score": 0.0009308889275416732},
+        ]
+    ]
+    return mock_response
+
+
+def test_hf_classifier_task():
+    try:
+        with patch("requests.post", side_effect=mock_post):
+            litellm.set_verbose = True
+            user_message = "I like you. I love you"
+            messages = [{"content": user_message, "role": "user"}]
+            response = completion(
+                model="huggingface/shahrukhx01/question-vs-statement-classifier",
+                messages=messages,
+                hf_task="text-classification",
+            )
+            print(f"response: {response}")
+            assert isinstance(response, litellm.ModelResponse)
+            assert isinstance(response.choices[0], litellm.Choices)
+            assert response.choices[0].message.content is not None
+            assert isinstance(response.choices[0].message.content, str)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {str(e)}")
+
+
 ########################### End of Hugging Face Tests ##############################################
 # def test_completion_hf_api():
 # # failing on circle ci commenting out


@@ -4840,6 +4840,7 @@ def get_optional_params_embeddings(
 def get_optional_params(
     # use the openai defaults
     # https://platform.openai.com/docs/api-reference/chat/create
+    model: str,
     functions=None,
     function_call=None,
     temperature=None,
@@ -4853,7 +4854,6 @@ def get_optional_params(
     frequency_penalty=None,
     logit_bias=None,
     user=None,
-    model=None,
     custom_llm_provider="",
     response_format=None,
     seed=None,
@@ -4882,7 +4882,7 @@ def get_optional_params(
             passed_params[k] = v

-    optional_params = {}
+    optional_params: Dict = {}

     common_auth_dict = litellm.common_cloud_provider_auth_params
     if custom_llm_provider in common_auth_dict["providers"]:
@@ -5156,41 +5156,9 @@ def get_optional_params(
             model=model, custom_llm_provider=custom_llm_provider
         )
         _check_valid_arg(supported_params=supported_params)
-        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
-        if temperature is not None:
-            if temperature == 0.0 or temperature == 0:
-                # hugging face exception raised when temp==0
-                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
-                temperature = 0.01
-            optional_params["temperature"] = temperature
-        if top_p is not None:
-            optional_params["top_p"] = top_p
-        if n is not None:
-            optional_params["best_of"] = n
-            optional_params["do_sample"] = (
-                True  # Need to sample if you want best of for hf inference endpoints
-            )
-        if stream is not None:
-            optional_params["stream"] = stream
-        if stop is not None:
-            optional_params["stop"] = stop
-        if max_tokens is not None:
-            # HF TGI raises the following exception when max_new_tokens==0
-            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
-            if max_tokens == 0:
-                max_tokens = 1
-            optional_params["max_new_tokens"] = max_tokens
-        if n is not None:
-            optional_params["best_of"] = n
-        if presence_penalty is not None:
-            optional_params["repetition_penalty"] = presence_penalty
-        if "echo" in passed_params:
-            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
-            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
-            optional_params["decoder_input_details"] = special_params["echo"]
-        passed_params.pop(
-            "echo", None
-        )  # since we handle translating echo, we should not send it to TGI request
+        optional_params = litellm.HuggingfaceConfig().map_openai_params(
+            non_default_params=non_default_params, optional_params=optional_params
+        )
     elif custom_llm_provider == "together_ai":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
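Side note (not part of the diff): a hedged end-to-end check of the delegation above, exercising both the now-required model argument and HuggingfaceConfig.map_openai_params. The import location is an assumption (this view does not show the file name), and the returned dict may contain additional defaults beyond the mapped keys.

import litellm
from litellm.utils import get_optional_params  # import path assumed

params = get_optional_params(
    model="HuggingFaceH4/zephyr-7b-beta",
    custom_llm_provider="huggingface",
    temperature=0.0,  # remapped to 0.01, since TGI rejects temperature == 0
    max_tokens=0,     # remapped to max_new_tokens = 1
    n=2,              # remapped to best_of = 2 with do_sample = True
)
print(params)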
@@ -6150,7 +6118,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "seed",
         ]
     elif custom_llm_provider == "huggingface":
-        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
+        return litellm.HuggingfaceConfig().get_supported_openai_params()
     elif custom_llm_provider == "together_ai":
         return [
             "stream",