fix(factory.py): support 'add_generation_prompt' field for hf chat templates

Fixes https://github.com/BerriAI/litellm/pull/5178#issuecomment-2306362008
Krrish Dholakia 2024-08-23 08:06:21 -07:00
parent afb00a27cb
commit 874d58fe8a
3 changed files with 84 additions and 7 deletions
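
For context, 'add_generation_prompt' is the flag that Hugging Face chat templates check to decide whether to append the assistant turn header (the cue for the model to start answering) at the end of the rendered prompt. Before this commit, hf_chat_template rendered templates without passing the flag, so templates that gate their trailing assistant header on it stopped after the last user message. A minimal sketch of the behaviour, using an illustrative ChatML-style template rather than any real model's template:

# Minimal sketch of what add_generation_prompt controls in a HF-style chat template.
# The template string below is illustrative (ChatML-like), not taken from a real model.
import jinja2

chat_template = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

template = jinja2.Environment().from_string(chat_template)
messages = [{"role": "user", "content": "hi"}]

# Without the flag, the rendered prompt stops after the last user turn.
print(template.render(messages=messages, add_generation_prompt=False))
# With the flag, the prompt ends with the assistant header, cueing the model to answer.
print(template.render(messages=messages, add_generation_prompt=True))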


@@ -179,6 +179,9 @@ class HuggingfaceConfig:
            optional_params["decoder_input_details"] = True
        return optional_params

    def get_hf_api_key(self) -> Optional[str]:
        return litellm.utils.get_secret("HUGGINGFACE_API_KEY")


def output_parser(generated_text: str):
    """


@@ -14,6 +14,7 @@ import litellm
import litellm.types
import litellm.types.llms
import litellm.types.llms.vertex_ai
from litellm import verbose_logger
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from litellm.types.completion import (
    ChatCompletionFunctionMessageParam,
@@ -380,12 +381,14 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
    if chat_template is None:

        def _get_tokenizer_config(hf_model_name):
            url = (
                f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
            )
            try:
                url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
                # Make a GET request to fetch the JSON data
                client = HTTPHandler(concurrent_limit=1)
                response = client.get(url)
            except Exception as e:
                raise e
            if response.status_code == 200:
                # Parse the JSON data
                tokenizer_config = json.loads(response.content)
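
The 'status'/'tokenizer' envelope that the rest of the function reads from _get_tokenizer_config is not fully visible in this hunk; judging from the lookups below (tokenizer_config["status"], tokenizer_config["tokenizer"]), it wraps the fetched JSON roughly as sketched here. Treat the shape as an inference from the surrounding diff, not a quote of the real helper.

# Illustrative stand-alone version of the fetch, mirroring the lines above and the
# envelope shape inferred from how tokenizer_config is consumed later in the function.
import json

from litellm.llms.custom_httpx.http_handler import HTTPHandler


def fetch_tokenizer_config(hf_model_name: str) -> dict:
    url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
    client = HTTPHandler(concurrent_limit=1)
    response = client.get(url)
    if response.status_code == 200:
        return {"status": "success", "tokenizer": json.loads(response.content)}
    return {"status": "failure"}


print(fetch_tokenizer_config("mistralai/Mistral-7B-Instruct-v0.1")["status"])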
@@ -397,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
            tokenizer_config = known_tokenizer_config[model]
        else:
            tokenizer_config = _get_tokenizer_config(model)
        if (
            tokenizer_config["status"] == "failure"
            or "chat_template" not in tokenizer_config["tokenizer"]
@@ -406,7 +410,13 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
        tokenizer_config = tokenizer_config["tokenizer"]  # type: ignore
        bos_token = tokenizer_config["bos_token"]  # type: ignore
        if bos_token is not None and not isinstance(bos_token, str):
            if isinstance(bos_token, dict):
                bos_token = bos_token.get("content", None)
        eos_token = tokenizer_config["eos_token"]  # type: ignore
        if eos_token is not None and not isinstance(eos_token, str):
            if isinstance(eos_token, dict):
                eos_token = eos_token.get("content", None)
        chat_template = tokenizer_config["chat_template"]  # type: ignore
    try:
        template = env.from_string(chat_template)  # type: ignore
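
The dict handling added above matters because some models' tokenizer_config.json stores 'bos_token' / 'eos_token' as AddedToken-style objects (a dict with a "content" field) rather than plain strings. A small sketch of the same normalization in isolation, with illustrative inputs:

# Sketch of the bos/eos normalization: dict-style entries collapse to their
# "content" value; plain strings and None pass through unchanged.
from typing import Any


def normalize_special_token(token: Any) -> Any:
    if token is not None and not isinstance(token, str):
        if isinstance(token, dict):
            return token.get("content", None)
    return token


print(normalize_special_token("<s>"))                                # "<s>"
print(normalize_special_token({"content": "<s>", "special": True}))  # "<s>" (dict form)
print(normalize_special_token(None))                                 # None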
@@ -431,7 +441,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
        # Render the template with the provided values
        if _is_system_in_template():
            rendered_text = template.render(
                bos_token=bos_token, eos_token=eos_token, messages=messages
                bos_token=bos_token,
                eos_token=eos_token,
                messages=messages,
                add_generation_prompt=True,
            )
        else:
            # treat a system message as a user message, if system not in template
@@ -448,6 +461,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
                    bos_token=bos_token,
                    eos_token=eos_token,
                    messages=reformatted_messages,
                    add_generation_prompt=True,
                )
            except Exception as e:
                if "Conversation roles must alternate user/assistant" in str(e):
@@ -469,8 +483,12 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
                    rendered_text = template.render(
                        bos_token=bos_token, eos_token=eos_token, messages=new_messages
                    )
        return rendered_text
    except Exception as e:
        verbose_logger.exception(
            "Error rendering huggingface chat template - {}".format(str(e))
        )
        raise Exception(f"Error rendering template - {str(e)}")


@@ -253,6 +253,62 @@ async def test_completion_sagemaker_non_stream():
    )


@pytest.mark.asyncio
async def test_completion_sagemaker_prompt_template_non_stream():
    mock_response = MagicMock()

    def return_val():
        return {
            "generated_text": "This is a mock response from SageMaker.",
            "id": "cmpl-mockid",
            "object": "text_completion",
            "created": 1629800000,
            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
            "choices": [
                {
                    "text": "This is a mock response from SageMaker.",
                    "index": 0,
                    "logprobs": None,
                    "finish_reason": "length",
                }
            ],
            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
        }

    mock_response.json = return_val
    mock_response.status_code = 200

    expected_payload = {
        "inputs": "<begin▁of▁sentence>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n\n### Instruction:\nhi\n\n\n### Response:\n",
        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
    }

    with patch(
        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
        return_value=mock_response,
    ) as mock_post:
        # Act: call litellm.completion against the mocked SageMaker endpoint
        response = litellm.completion(
            model="sagemaker/deepseek_coder_6.7_instruct",
            messages=[
                {"role": "user", "content": "hi"},
            ],
            temperature=0.2,
            max_tokens=80,
            hf_model_name="mistralai/Mistral-7B-Instruct-v0.1",
        )

        # Print what was called on the mock
        print("call args=", mock_post.call_args)

        # Assert the request payload carries the expected rendered prompt
        mock_post.assert_called_once()
        _, kwargs = mock_post.call_args
        args_to_sagemaker = kwargs["json"]
        print("Arguments passed to sagemaker=", args_to_sagemaker)
        assert args_to_sagemaker == expected_payload


@pytest.mark.asyncio
async def test_completion_sagemaker_non_stream_with_aws_params():
    mock_response = MagicMock()