fix(factory.py): support 'add_generation_prompt' field for hf chat templates

Fixes https://github.com/BerriAI/litellm/pull/5178#issuecomment-2306362008
Krrish Dholakia 2024-08-23 08:06:21 -07:00
parent afb00a27cb
commit 874d58fe8a
3 changed files with 84 additions and 7 deletions
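
For context on the title: add_generation_prompt is the flag Hugging Face chat templates check to decide whether to append the final assistant-turn header that cues the model to start its reply. Below is a minimal sketch of the effect, rendered with jinja2 directly; the template string and token values are illustrative only, not taken from any particular model, and litellm's hf_chat_template builds its own `env` and pulls the real template from tokenizer_config.json.

    import jinja2

    # Simplified chat-style template; real templates ship in tokenizer_config.json.
    CHAT_TEMPLATE = (
        "{{ bos_token }}"
        "{% for message in messages %}"
        "<|{{ message['role'] }}|>\n{{ message['content'] }}{{ eos_token }}\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
    )

    env = jinja2.Environment()
    template = env.from_string(CHAT_TEMPLATE)

    prompt = template.render(
        bos_token="<s>",
        eos_token="</s>",
        messages=[{"role": "user", "content": "hi"}],
        add_generation_prompt=True,  # without this, the trailing "<|assistant|>" cue is dropped
    )
    print(prompt)

Passing add_generation_prompt=True in the template.render calls below is what makes the rendered prompt end with the assistant cue instead of stopping after the last user turn.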


@@ -179,6 +179,9 @@ class HuggingfaceConfig:
             optional_params["decoder_input_details"] = True
         return optional_params
+
+    def get_hf_api_key(self) -> Optional[str]:
+        return litellm.utils.get_secret("HUGGINGFACE_API_KEY")


 def output_parser(generated_text: str):
     """


@@ -14,6 +14,7 @@ import litellm
 import litellm.types
 import litellm.types.llms
 import litellm.types.llms.vertex_ai
+from litellm import verbose_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.types.completion import (
     ChatCompletionFunctionMessageParam,
@@ -380,12 +381,14 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
     if chat_template is None:

         def _get_tokenizer_config(hf_model_name):
-            url = (
-                f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-            )
-            # Make a GET request to fetch the JSON data
-            client = HTTPHandler(concurrent_limit=1)
-            response = client.get(url)
+            try:
+                url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
+
+                # Make a GET request to fetch the JSON data
+                client = HTTPHandler(concurrent_limit=1)
+                response = client.get(url)
+            except Exception as e:
+                raise e
             if response.status_code == 200:
                 # Parse the JSON data
                 tokenizer_config = json.loads(response.content)
@@ -397,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
             tokenizer_config = known_tokenizer_config[model]
         else:
             tokenizer_config = _get_tokenizer_config(model)
+
         if (
             tokenizer_config["status"] == "failure"
             or "chat_template" not in tokenizer_config["tokenizer"]
@@ -406,7 +410,13 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
         tokenizer_config = tokenizer_config["tokenizer"]  # type: ignore

         bos_token = tokenizer_config["bos_token"]  # type: ignore
+        if bos_token is not None and not isinstance(bos_token, str):
+            if isinstance(bos_token, dict):
+                bos_token = bos_token.get("content", None)
         eos_token = tokenizer_config["eos_token"]  # type: ignore
+        if eos_token is not None and not isinstance(eos_token, str):
+            if isinstance(eos_token, dict):
+                eos_token = eos_token.get("content", None)
         chat_template = tokenizer_config["chat_template"]  # type: ignore
     try:
         template = env.from_string(chat_template)  # type: ignore
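
The new bos_token/eos_token checks above handle tokenizer_config.json files where the special tokens are serialized as AddedToken-style objects (dicts with a "content" field) rather than plain strings. A small sketch of that normalization in isolation; the helper name _token_content and the example dict are illustrative, not part of the commit:

    from typing import Any, Optional


    def _token_content(token: Any) -> Optional[str]:
        # Plain string (or missing token): return as-is.
        if token is None or isinstance(token, str):
            return token
        # AddedToken-style dict, e.g. {"content": "<s>", "lstrip": False, ...}:
        # keep only the literal token text.
        if isinstance(token, dict):
            return token.get("content", None)
        return None


    print(_token_content("<s>"))                                # -> <s>
    print(_token_content({"content": "<s>", "lstrip": False}))  # -> <s>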
@@ -431,7 +441,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
         # Render the template with the provided values
         if _is_system_in_template():
             rendered_text = template.render(
-                bos_token=bos_token, eos_token=eos_token, messages=messages
+                bos_token=bos_token,
+                eos_token=eos_token,
+                messages=messages,
+                add_generation_prompt=True,
             )
         else:
             # treat a system message as a user message, if system not in template
@@ -448,6 +461,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
                     bos_token=bos_token,
                     eos_token=eos_token,
                     messages=reformatted_messages,
+                    add_generation_prompt=True,
                 )
             except Exception as e:
                 if "Conversation roles must alternate user/assistant" in str(e):
@@ -469,8 +483,12 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
                     rendered_text = template.render(
                         bos_token=bos_token, eos_token=eos_token, messages=new_messages
                     )

         return rendered_text
     except Exception as e:
+        verbose_logger.exception(
+            "Error rendering huggingface chat template - {}".format(str(e))
+        )
         raise Exception(f"Error rendering template - {str(e)}")


@@ -253,6 +253,62 @@ async def test_completion_sagemaker_non_stream():
     )


+@pytest.mark.asyncio
+async def test_completion_sagemaker_prompt_template_non_stream():
+    mock_response = MagicMock()
+
+    def return_val():
+        return {
+            "generated_text": "This is a mock response from SageMaker.",
+            "id": "cmpl-mockid",
+            "object": "text_completion",
+            "created": 1629800000,
+            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            "choices": [
+                {
+                    "text": "This is a mock response from SageMaker.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "length",
+                }
+            ],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "inputs": "<begin▁of▁sentence>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n\n### Instruction:\nhi\n\n\n### Response:\n",
+        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.acompletion function
+        response = litellm.completion(
+            model="sagemaker/deepseek_coder_6.7_instruct",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            hf_model_name="mistralai/Mistral-7B-Instruct-v0.1",
+        )
+
+        # Print what was called on the mock
+        print("call args=", mock_post.call_args)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_sagemaker = kwargs["json"]
+        print("Arguments passed to sagemaker=", args_to_sagemaker)
+        assert args_to_sagemaker == expected_payload
+
+
 @pytest.mark.asyncio
 async def test_completion_sagemaker_non_stream_with_aws_params():
     mock_response = MagicMock()