forked from phoenix/litellm-mirror
test(test_completion.py): reintegrate testing for huggingface tgi + non-tgi
This commit is contained in:
parent
781d5888c3
commit
c17f221b89
3 changed files with 218 additions and 58 deletions
@@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum


class HuggingfaceError(Exception):
@@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
        )  # Call the base class constructor with the parameters it needs


hf_task_list = [
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]

hf_tasks = Literal[
    "text-generation-inference",
    "conversational",
    "text-classification",
    "text-generation",
]


class HuggingfaceConfig:
    """
    Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
    """

    hf_task: Optional[hf_tasks] = (
        None  # litellm-specific param, used to know the api spec to use when calling huggingface api
    )
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: Optional[bool] = True  # enables returning logprobs + best of
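The hf_task field above is the litellm-specific knob for pinning the Hugging Face API spec per request instead of relying on model-name lookup. A minimal sketch of passing it through completion(), mirroring the classifier test added later in this commit (the model and prompt are the ones used in that test):

import litellm

# hf_task must be one of the hf_tasks Literal values defined above
response = litellm.completion(
    model="huggingface/shahrukhx01/question-vs-statement-classifier",
    messages=[{"role": "user", "content": "I like you. I love you"}],
    hf_task="text-classification",
)
print(response.choices[0].message.content)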
@@ -101,6 +121,51 @@ class HuggingfaceConfig:
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "echo",
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
            if param == "temperature":
                if value == 0.0 or value == 0:
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "n":
                optional_params["best_of"] = value
                optional_params["do_sample"] = (
                    True  # Need to sample if you want best of for hf inference endpoints
                )
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop"] = value
            if param == "max_tokens":
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                optional_params["max_new_tokens"] = value
            if param == "echo":
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
        return optional_params


def output_parser(generated_text: str):
    """
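These two methods are what utils.get_optional_params delegates to for the huggingface provider later in this commit. A hedged sketch of the translation they perform, with an illustrative input dict:

import litellm

# illustrative OpenAI-style params, chosen to hit the edge cases handled above
hf_params = litellm.HuggingfaceConfig().map_openai_params(
    non_default_params={"temperature": 0, "max_tokens": 0, "n": 2, "echo": True},
    optional_params={},
)
# expected per the branches above: temperature bumped to 0.01, max_new_tokens bumped to 1,
# best_of=2 with do_sample=True, and decoder_input_details=True
print(hf_params)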
@@ -162,7 +227,7 @@ def read_tgi_conv_models():
    return set(), set()


def get_hf_task_for_model(model):
def get_hf_task_for_model(model: str) -> hf_tasks:
    # read text file, cast it to set
    # read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
    tgi_models, conversational_models = read_tgi_conv_models()
@@ -171,7 +236,7 @@ def get_hf_task_for_model(model):
    elif model in conversational_models:
        return "conversational"
    elif "roneneldan/TinyStories" in model:
        return None
        return "text-generation"
    else:
        return "text-generation-inference"  # default to tgi
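With the annotated signature, get_hf_task_for_model always resolves to one of the hf_tasks values, and the roneneldan/TinyStories special case now returns "text-generation" instead of None. A hedged sketch of the routing (the import path is an assumption, since the handler's file name is not shown in this diff):

# assumed module path for the Hugging Face handler; adjust if it differs
from litellm.llms.huggingface_restapi import get_hf_task_for_model

print(get_hf_task_for_model("roneneldan/TinyStories-1M"))  # "text-generation" (was None before this change)
print(get_hf_task_for_model("some-org/unlisted-model"))    # "text-generation-inference" (the default)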
@@ -202,7 +267,7 @@ class Huggingface(BaseLLM):
        self,
        completion_response,
        model_response,
        task,
        task: hf_tasks,
        optional_params,
        encoding,
        input_text,
@@ -270,6 +335,10 @@ class Huggingface(BaseLLM):
                    )
                    choices_list.append(choice_obj)
                model_response["choices"].extend(choices_list)
        elif task == "text-classification":
            model_response["choices"][0]["message"]["content"] = json.dumps(
                completion_response
            )
        else:
            if len(completion_response[0]["generated_text"]) > 0:
                model_response["choices"][0]["message"]["content"] = output_parser(
@@ -332,7 +401,16 @@ class Huggingface(BaseLLM):
        exception_mapping_worked = False
        try:
            headers = self.validate_environment(api_key, headers)
            task = get_hf_task_for_model(model)
            if optional_params.get("hf_task") is None:
                task = get_hf_task_for_model(model)
            else:
                task = optional_params.get("hf_task")  # type: ignore
            ## VALIDATE API FORMAT
            if task is None or not isinstance(task, str) or task not in hf_task_list:
                raise Exception(
                    "Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
                )

            print_verbose(f"{model}, {task}")
            completion_url = ""
            input_text = ""
@@ -433,14 +511,15 @@ class Huggingface(BaseLLM):
                    inference_params.pop("return_full_text")
                data = {
                    "inputs": prompt,
                    "parameters": inference_params,
                    "stream": (  # type: ignore
                }
                if task == "text-generation-inference":
                    data["parameters"] = inference_params
                    data["stream"] = (  # type: ignore
                        True
                        if "stream" in optional_params
                        and optional_params["stream"] == True
                        else False
                    ),
                }
                    )
                input_text = prompt
                ## LOGGING
                logging_obj.pre_call(
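Payload construction is now split: every task posts the bare inputs, and only text-generation-inference layers parameters and stream on top. A hedged sketch of the resulting request bodies (field values are illustrative):

# non-TGI task (e.g. text-classification): only the raw inputs are sent
data = {"inputs": "I like you. I love you"}

# text-generation-inference: parameters and stream are added afterwards
data = {
    "inputs": "<prompt rendered by prompt_factory>",
    "parameters": {"max_new_tokens": 10, "details": True},
    "stream": False,
}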
@@ -531,10 +610,10 @@ class Huggingface(BaseLLM):
                isinstance(completion_response, dict)
                and "error" in completion_response
            ):
                print_verbose(f"completion error: {completion_response['error']}")
                print_verbose(f"completion error: {completion_response['error']}")  # type: ignore
                print_verbose(f"response.status_code: {response.status_code}")
                raise HuggingfaceError(
                    message=completion_response["error"],
                    message=completion_response["error"],  # type: ignore
                    status_code=response.status_code,
                )
            return self.convert_to_model_response_object(
@@ -563,7 +642,7 @@ class Huggingface(BaseLLM):
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        task: str,
        task: hf_tasks,
        encoding: Any,
        input_text: str,
        model: str,

@@ -13,6 +13,7 @@ import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError
from litellm.llms.prompt_templates.factory import anthropic_messages_pt
from unittest.mock import patch, MagicMock

# litellm.num_retries=3
litellm.cache = None
@@ -1145,15 +1146,92 @@ def test_get_hf_task_for_model():
# ################### Hugging Face TGI models ########################
# # TGI model
# # this is a TGI model https://huggingface.co/glaiveai/glaive-coder-7b
def hf_test_completion_tgi():
    # litellm.set_verbose=True
def tgi_mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        {
            "generated_text": "<|assistant|>\nI'm",
            "details": {
                "finish_reason": "length",
                "generated_tokens": 10,
                "seed": None,
                "prefill": [],
                "tokens": [
                    {
                        "id": 28789,
                        "text": "<",
                        "logprob": -0.025222778,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.000003695488,
                        "special": False,
                    },
                    {
                        "id": 489,
                        "text": "ass",
                        "logprob": -0.0000019073486,
                        "special": False,
                    },
                    {
                        "id": 11143,
                        "text": "istant",
                        "logprob": -0.000002026558,
                        "special": False,
                    },
                    {
                        "id": 28766,
                        "text": "|",
                        "logprob": -0.0000015497208,
                        "special": False,
                    },
                    {
                        "id": 28767,
                        "text": ">",
                        "logprob": -0.0000011920929,
                        "special": False,
                    },
                    {
                        "id": 13,
                        "text": "\n",
                        "logprob": -0.00009703636,
                        "special": False,
                    },
                    {"id": 28737, "text": "I", "logprob": -0.1953125, "special": False},
                    {
                        "id": 28742,
                        "text": "'",
                        "logprob": -0.88183594,
                        "special": False,
                    },
                    {
                        "id": 28719,
                        "text": "m",
                        "logprob": -0.00032639503,
                        "special": False,
                    },
                ],
            },
        }
    ]
    return mock_response


def test_hf_test_completion_tgi():
    litellm.set_verbose = True
    try:
        response = completion(
            model="huggingface/HuggingFaceH4/zephyr-7b-beta",
            messages=[{"content": "Hello, how are you?", "role": "user"}],
        )
        # Add any assertions here to check the response
        print(response)
        with patch("requests.post", side_effect=tgi_mock_post):
            response = completion(
                model="huggingface/HuggingFaceH4/zephyr-7b-beta",
                messages=[{"content": "Hello, how are you?", "role": "user"}],
                max_tokens=10,
            )
            # Add any assertions here to check the response
            print(response)
    except litellm.ServiceUnavailableError as e:
        pass
    except Exception as e:
@@ -1191,6 +1269,41 @@ def hf_test_completion_tgi():
# except Exception as e:
#     pytest.fail(f"Error occurred: {e}")
# hf_test_completion_none_task()


def mock_post(url, data=None, json=None, headers=None):
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = [
        [
            {"label": "LABEL_0", "score": 0.9990691542625427},
            {"label": "LABEL_1", "score": 0.0009308889275416732},
        ]
    ]
    return mock_response


def test_hf_classifier_task():
    try:
        with patch("requests.post", side_effect=mock_post):
            litellm.set_verbose = True
            user_message = "I like you. I love you"
            messages = [{"content": user_message, "role": "user"}]
            response = completion(
                model="huggingface/shahrukhx01/question-vs-statement-classifier",
                messages=messages,
                hf_task="text-classification",
            )
            print(f"response: {response}")
            assert isinstance(response, litellm.ModelResponse)
            assert isinstance(response.choices[0], litellm.Choices)
            assert response.choices[0].message.content is not None
            assert isinstance(response.choices[0].message.content, str)
    except Exception as e:
        pytest.fail(f"Error occurred: {str(e)}")
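Given the mocked classifier payload, the text-classification branch added to convert_to_model_response_object serializes the raw label/score list into the message content, so the string assertions above pass against roughly this value (a sketch, not captured output):

# approximate response.choices[0].message.content for the mocked call above
'[[{"label": "LABEL_0", "score": 0.9990691542625427}, {"label": "LABEL_1", "score": 0.0009308889275416732}]]'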
########################### End of Hugging Face Tests ##############################################
# def test_completion_hf_api():
#     # failing on circle ci commenting out

@@ -4840,6 +4840,7 @@ def get_optional_params_embeddings(
def get_optional_params(
    # use the openai defaults
    # https://platform.openai.com/docs/api-reference/chat/create
    model: str,
    functions=None,
    function_call=None,
    temperature=None,
@@ -4853,7 +4854,6 @@ def get_optional_params(
    frequency_penalty=None,
    logit_bias=None,
    user=None,
    model=None,
    custom_llm_provider="",
    response_format=None,
    seed=None,
@@ -4882,7 +4882,7 @@ def get_optional_params(

            passed_params[k] = v

    optional_params = {}
    optional_params: Dict = {}

    common_auth_dict = litellm.common_cloud_provider_auth_params
    if custom_llm_provider in common_auth_dict["providers"]:
@@ -5156,41 +5156,9 @@ def get_optional_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None
        if temperature is not None:
            if temperature == 0.0 or temperature == 0:
                # hugging face exception raised when temp==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                temperature = 0.01
            optional_params["temperature"] = temperature
        if top_p is not None:
            optional_params["top_p"] = top_p
        if n is not None:
            optional_params["best_of"] = n
            optional_params["do_sample"] = (
                True  # Need to sample if you want best of for hf inference endpoints
            )
        if stream is not None:
            optional_params["stream"] = stream
        if stop is not None:
            optional_params["stop"] = stop
        if max_tokens is not None:
            # HF TGI raises the following exception when max_new_tokens==0
            # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
            if max_tokens == 0:
                max_tokens = 1
            optional_params["max_new_tokens"] = max_tokens
        if n is not None:
            optional_params["best_of"] = n
        if presence_penalty is not None:
            optional_params["repetition_penalty"] = presence_penalty
        if "echo" in passed_params:
            # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
            # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
            optional_params["decoder_input_details"] = special_params["echo"]
            passed_params.pop(
                "echo", None
            )  # since we handle translating echo, we should not send it to TGI request
        optional_params = litellm.HuggingfaceConfig().map_openai_params(
            non_default_params=non_default_params, optional_params=optional_params
        )
    elif custom_llm_provider == "together_ai":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
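Every removed inline branch is covered by HuggingfaceConfig.map_openai_params, so get_optional_params reaches the same result through the config class. A hedged sketch of the call path (keyword values are illustrative, and the litellm.utils import path is an assumption since the file name is not shown in this diff):

from litellm.utils import get_optional_params

optional_params = get_optional_params(
    model="HuggingFaceH4/zephyr-7b-beta",
    custom_llm_provider="huggingface",
    temperature=0,
    max_tokens=10,
)
# per map_openai_params above: temperature is bumped to 0.01 and max_tokens becomes max_new_tokens=10
print(optional_params)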
@@ -6150,7 +6118,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "seed",
        ]
    elif custom_llm_provider == "huggingface":
        return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
        return litellm.HuggingfaceConfig().get_supported_openai_params()
    elif custom_llm_provider == "together_ai":
        return [
            "stream",
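Because the huggingface branch now defers to the config class, the supported-params list picks up "echo" automatically. A one-line sketch (importing from litellm.utils is an assumption; the file name is not shown in this diff):

from litellm.utils import get_supported_openai_params

print(get_supported_openai_params(model="HuggingFaceH4/zephyr-7b-beta", custom_llm_provider="huggingface"))
# expected: ["stream", "temperature", "max_tokens", "top_p", "stop", "n", "echo"]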