test(test_completion.py): reintegrate testing for huggingface tgi + non-tgi

Krrish Dholakia 2024-05-10 14:07:01 -07:00
parent 781d5888c3
commit c17f221b89
3 changed files with 218 additions and 58 deletions

@@ -6,10 +6,12 @@ import httpx, requests
from .base import BaseLLM
import time
import litellm
from typing import Callable, Dict, List, Any
from typing import Callable, Dict, List, Any, Literal
from litellm.utils import ModelResponse, Choices, Message, CustomStreamWrapper, Usage
from typing import Optional
from .prompt_templates.factory import prompt_factory, custom_prompt
from litellm.types.completion import ChatCompletionMessageToolCallParam
import enum
class HuggingfaceError(Exception):
@@ -39,11 +41,29 @@ class HuggingfaceError(Exception):
) # Call the base class constructor with the parameters it needs
hf_task_list = [
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
hf_tasks = Literal[
"text-generation-inference",
"conversational",
"text-classification",
"text-generation",
]
class HuggingfaceConfig:
"""
Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
"""
hf_task: Optional[hf_tasks] = (
None # litellm-specific param, used to know the api spec to use when calling huggingface api
)
best_of: Optional[int] = None
decoder_input_details: Optional[bool] = None
details: Optional[bool] = True # enables returning logprobs + best of
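
For reference, a minimal standalone sketch of how the hf_tasks Literal lines up with hf_task_list and how a task value can be checked at runtime; the validate_task helper below is hypothetical, not part of this commit:

    from typing import Literal, get_args

    hf_tasks = Literal[
        "text-generation-inference",
        "conversational",
        "text-classification",
        "text-generation",
    ]
    # get_args recovers the same four strings that hf_task_list spells out by hand
    hf_task_list = list(get_args(hf_tasks))

    def validate_task(task: str) -> None:
        # hypothetical helper mirroring the validation added further down in this commit
        if task not in hf_task_list:
            raise ValueError(f"Invalid hf task - {task}. Valid formats - {hf_task_list}.")
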
@@ -101,6 +121,51 @@ class HuggingfaceConfig:
and v is not None
}
def get_supported_openai_params(self):
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"stop",
"n",
"echo",
]
def map_openai_params(
self, non_default_params: dict, optional_params: dict
) -> dict:
for param, value in non_default_params.items():
# temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
if param == "temperature":
if value == 0.0 or value == 0:
# hugging face exception raised when temp==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
value = 0.01
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "n":
optional_params["best_of"] = value
optional_params["do_sample"] = (
True # Need to sample if you want best of for hf inference endpoints
)
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
if param == "max_tokens":
# HF TGI raises the following exception when max_new_tokens==0
# Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
if value == 0:
value = 1
optional_params["max_new_tokens"] = value
if param == "echo":
# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
# Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
optional_params["decoder_input_details"] = True
return optional_params
def output_parser(generated_text: str):
"""
@@ -162,7 +227,7 @@ def read_tgi_conv_models():
return set(), set()
def get_hf_task_for_model(model):
def get_hf_task_for_model(model: str) -> hf_tasks:
# read text file, cast it to set
# read the file called "huggingface_llms_metadata/hf_text_generation_models.txt"
tgi_models, conversational_models = read_tgi_conv_models()
@@ -171,7 +236,7 @@ def get_hf_task_for_model(model):
elif model in conversational_models:
return "conversational"
elif "roneneldan/TinyStories" in model:
return None
return "text-generation"
else:
return "text-generation-inference" # default to tgi
@@ -202,7 +267,7 @@ class Huggingface(BaseLLM):
self,
completion_response,
model_response,
task,
task: hf_tasks,
optional_params,
encoding,
input_text,
@@ -270,6 +335,10 @@ class Huggingface(BaseLLM):
)
choices_list.append(choice_obj)
model_response["choices"].extend(choices_list)
elif task == "text-classification":
model_response["choices"][0]["message"]["content"] = json.dumps(
completion_response
)
else:
if len(completion_response[0]["generated_text"]) > 0:
model_response["choices"][0]["message"]["content"] = output_parser(
@@ -332,7 +401,16 @@ class Huggingface(BaseLLM):
exception_mapping_worked = False
try:
headers = self.validate_environment(api_key, headers)
task = get_hf_task_for_model(model)
if optional_params.get("hf_task") is None:
task = get_hf_task_for_model(model)
else:
task = optional_params.get("hf_task") # type: ignore
## VALIDATE API FORMAT
if task is None or not isinstance(task, str) or task not in hf_task_list:
raise Exception(
"Invalid hf task - {}. Valid formats - {}.".format(task, hf_tasks)
)
print_verbose(f"{model}, {task}")
completion_url = ""
input_text = ""
@@ -433,14 +511,15 @@ class Huggingface(BaseLLM):
inference_params.pop("return_full_text")
data = {
"inputs": prompt,
"parameters": inference_params,
"stream": ( # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params
data["stream"] = ( # type: ignore
True
if "stream" in optional_params
and optional_params["stream"] == True
else False
),
}
)
input_text = prompt
## LOGGING
logging_obj.pre_call(
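
The net effect of the payload change above is that "parameters" and "stream" are only attached for TGI-style endpoints, while other tasks keep the bare inputs body built here (further handling for those tasks sits outside this hunk). A small self-contained sketch with illustrative values:

    prompt = "Once upon a time"
    inference_params = {"max_new_tokens": 256, "details": True}
    task = "text-generation-inference"
    optional_params = {"stream": False}

    data = {"inputs": prompt}
    if task == "text-generation-inference":
        data["parameters"] = inference_params
        data["stream"] = optional_params.get("stream") is True
    # Non-TGI tasks leave data as just {"inputs": prompt} at this point.
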
@@ -531,10 +610,10 @@ class Huggingface(BaseLLM):
isinstance(completion_response, dict)
and "error" in completion_response
):
print_verbose(f"completion error: {completion_response['error']}")
print_verbose(f"completion error: {completion_response['error']}") # type: ignore
print_verbose(f"response.status_code: {response.status_code}")
raise HuggingfaceError(
message=completion_response["error"],
message=completion_response["error"], # type: ignore
status_code=response.status_code,
)
return self.convert_to_model_response_object(
@@ -563,7 +642,7 @@ class Huggingface(BaseLLM):
data: dict,
headers: dict,
model_response: ModelResponse,
task: str,
task: hf_tasks,
encoding: Any,
input_text: str,
model: str,