test(test_completion.py): reintegrate testing for huggingface tgi + non-tgi

Krrish Dholakia 2024-05-10 14:07:01 -07:00
parent 781d5888c3
commit c17f221b89
3 changed files with 218 additions and 58 deletions
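Not part of the diff — a minimal usage sketch of what this commit reintegrates, based on the tests added below; the model names come from those tests, and the calls hit the live Hugging Face API unless requests.post is mocked as in the tests.

import litellm

# TGI endpoint: the task defaults to "text-generation-inference" via get_hf_task_for_model
tgi_response = litellm.completion(
    model="huggingface/HuggingFaceH4/zephyr-7b-beta",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    max_tokens=10,
)

# Non-TGI endpoint: pin the API spec with the new litellm-specific hf_task param
clf_response = litellm.completion(
    model="huggingface/shahrukhx01/question-vs-statement-classifier",
    messages=[{"content": "I like you. I love you", "role": "user"}],
    hf_task="text-classification",
)
print(clf_response.choices[0].message.content)  # JSON string of label/score pairs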

litellm/llms/huggingface_restapi.py

@@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum
class HuggingfaceError(Exception):
@@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
) # Call the base class constructor with the parameters it needs
hf_task_list = [
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
hf_tasks = Literal[
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
class HuggingfaceConfig:
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[hf_tasks] = (
None # litellm-specific param, used to know the api spec to use when calling huggingface api
)
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
@@ -101,6 +121,51 @@ class HuggingfaceConfig:
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
def output_parser(generated_text: str):
"""
@@ -162,7 +227,7 @@ def read_tgi_conv_models():
return set(), set()
def get_hf_task_for_model(model):
def get_hf_task_for_model(model: str) -> hf_tasks:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
tgi_models, conversational_models = read_tgi_conv_models()
@@ -171,7 +236,7 @@ def get_hf_task_for_model(model):
elif model in conversational_models:
return "conversational"
elif "roneneldan/TinyStories" in model:
return None
return "text-generation"
else:
return "text-generation-inference" # default to tgi
@@ -202,7 +267,7 @@ class Huggingface(BaseLLM):
self,
completion_response,
model_response,
task,
task: hf_tasks,
optional_params,
encoding,
input_text,
@@ -270,6 +335,10 @@ class Huggingface(BaseLLM):
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
elif task == "text-classification":
model_response["choices"][0]["message"]["content"] = json.dumps(
completion_response
)
else:
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
@@ -332,7 +401,16 @@ class Huggingface(BaseLLM):
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key, headers)
task = get_hf_task_for_model(model)
if optional_params.get("hf_task") is None:
task = get_hf_task_for_model(model)
else:
task = optional_params.get("hf_task") # type: ignore
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
print_verbose(f"{model}, {task}")
completion_url = ""
input_text = ""
@@ -433,14 +511,15 @@
inference_params.pop("return_full_text")
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": ( # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True
if "stream" in optional_params
and optional_params["stream"] == True
else False
),
}
)
input_text = prompt
## LOGGING
logging_obj.pre_call(
@@ -531,10 +610,10 @@
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError(
message=completion_response["error"],
message=completion_response["error"], # type: ignore
status_code=response.status_code,
)
return self.convert_to_model_response_object(
@@ -563,7 +642,7 @@
data: dict,
headers: dict,
model_response: ModelResponse,
task: str,
task: hf_tasks,
encoding: Any,
input_text: str,
model: str,
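Not part of the diff — a rough sketch of the relocated parameter mapping, using only names that appear in the hunks above; the expected values are read off map_openai_params, not verified output.

import litellm

config = litellm.HuggingfaceConfig()
print(config.get_supported_openai_params())
# expected: ["stream", "temperature", "max_tokens", "top_p", "stop", "n", "echo"]

mapped = config.map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2, "echo": True},
    optional_params={},
)
# expected per the code above:
# temperature 0 -> 0.01, max_tokens 0 -> max_new_tokens=1,
# n -> best_of=2 (plus do_sample=True), echo -> decoder_input_details=True
print(mapped)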

litellm/tests/test_completion.py

@@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock
# litellm.num_retries=3
litellm.cache = None
@@ -1145,15 +1146,92 @@ def test_get_hf_task_for_model():
# ################### Hugging Face TGI models ########################
# # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
def hf_test_completion_tgi():
# litellm.set_verbose=True
def tgi_mock_post(url, data=None, json=None, headers=None):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = [
{
"generated_text": "<|assistant|>\nI'm",
"details": {
"finish_reason": "length",
"generated_tokens": 10,
"seed": None,
"prefill": [],
"tokens": [
{
"id": 28789,
"text": "<",
"logprob": -0.025222778,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.000003695488,
"special": False,
},
{
"id": 489,
"text": "ass",
"logprob": -0.0000019073486,
"special": False,
},
{
"id": 11143,
"text": "istant",
"logprob": -0.000002026558,
"special": False,
},
{
"id": 28766,
"text": "|",
"logprob": -0.0000015497208,
"special": False,
},
{
"id": 28767,
"text": ">",
"logprob": -0.0000011920929,
"special": False,
},
{
"id": 13,
"text": "\n",
"logprob": -0.00009703636,
"special": False,
},
{"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
{
"id": 28742,
"text": "'",
"logprob": -0.88183594,
"special": False,
},
{
"id": 28719,
"text": "m",
"logprob": -0.00032639503,
"special": False,
},
],
},
}
]
return mock_response
def test_hf_test_completion_tgi():
litellm.set_verbose = True
try:
response = completion(
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
messages=[{"content": "Hello, how are you?", "role": "user"}],
)
# Add any assertions here to check the response
print(response)
with patch("requests.post", side_effect=tgi_mock_post):
response = completion(
model="huggingface/HuggingFaceH4/zephyr-7b-beta",
messages=[{"content": "Hello, how are you?", "role": "user"}],
max_tokens=10,
)
# Add any assertions here to check the response
print(response)
except litellm.ServiceUnavailableError as e:
pass
except Exception as e:
@@ -1191,6 +1269,41 @@ def hf_test_completion_tgi():
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
# hf_test_completion_none_task()
def mock_post(url, data=None, json=None, headers=None):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = [
[
{"label": "LABEL_0", "score": 0.9990691542625427},
{"label": "LABEL_1", "score": 0.0009308889275416732},
]
]
return mock_response
def test_hf_classifier_task():
try:
with patch("requests.post", side_effect=mock_post):
litellm.set_verbose = True
user_message = "I like you. I love you"
messages = [{"content": user_message, "role": "user"}]
response = completion(
model="huggingface/shahrukhx01/question-vs-statement-classifier",
messages=messages,
hf_task="text-classification",
)
print(f"response: {response}")
assert isinstance(response, litellm.ModelResponse)
assert isinstance(response.choices[0], litellm.Choices)
assert response.choices[0].message.content is not None
assert isinstance(response.choices[0].message.content, str)
except Exception as e:
pytest.fail(f"Error occurred: {str(e)}")
########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api():
# # failing on circle ci commenting out
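Not part of the diff — a sketch of reading the text-classification result back out, assuming the json.dumps handling added in convert_to_model_response_object above; the mock payload mirrors mock_post from the tests, and the exact content shape is an assumption.

import json
from unittest.mock import MagicMock, patch

import litellm

def classifier_mock_post(url, data=None, json=None, headers=None):
    # same response shape as mock_post above
    resp = MagicMock()
    resp.status_code = 200
    resp.headers = {"Content-Type": "application/json"}
    resp.json.return_value = [
        [
            {"label": "LABEL_0", "score": 0.999},
            {"label": "LABEL_1", "score": 0.001},
        ]
    ]
    return resp

with patch("requests.post", side_effect=classifier_mock_post):
    response = litellm.completion(
        model="huggingface/shahrukhx01/question-vs-statement-classifier",
        messages=[{"content": "I like you. I love you", "role": "user"}],
        hf_task="text-classification",
    )

# message.content should be the JSON-encoded classifier output
scores = json.loads(response.choices[0].message.content)
assert scores[0][0]["label"] == "LABEL_0"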

litellm/utils.py

@@ -4840,6 +4840,7 @@ def get_optional_params_embeddings(
def get_optional_params(
# use the openai defaults
# https://platform.openai.com/docs/api-reference/chat/create
model: str,
functions=None,
function_call=None,
temperature=None,
@@ -4853,7 +4854,6 @@
frequency_penalty=None,
logit_bias=None,
user=None,
model=None,
custom_llm_provider="",
response_format=None,
seed=None,
@@ -4882,7 +4882,7 @@
passed_params[k] = v
optional_params = {}
optional_params: Dict = {}
common_auth_dict = litellm.common_cloud_provider_auth_params
if custom_llm_provider in common_auth_dict["providers"]:
@@ -5156,41 +5156,9 @@
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
# temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
if temperature is not None:
if temperature == 0.0 or temperature == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
temperature = 0.01
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if n is not None:
optional_params["best_of"] = n
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if stream is not None:
optional_params["stream"] = stream
if stop is not None:
optional_params["stop"] = stop
if max_tokens is not None:
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if max_tokens == 0:
max_tokens = 1
optional_params["max_new_tokens"] = max_tokens
if n is not None:
optional_params["best_of"] = n
if presence_penalty is not None:
optional_params["repetition_penalty"] = presence_penalty
if "echo" in passed_params:
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = special_params["echo"]
passed_params.pop(
"echo", None
) # since we handle translating echo, we should not send it to TGI request
optional_params = litellm.HuggingfaceConfig().map_openai_params(
non_default_params=non_default_params, optional_params=optional_params
)
elif custom_llm_provider == "together_ai":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@@ -6150,7 +6118,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"seed",
]
elif custom_llm_provider == "huggingface":
return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
return litellm.HuggingfaceConfig().get_supported_openai_params()
elif custom_llm_provider == "together_ai":
return [
"stream",