forked from phoenix/litellm-mirror
fix(factory.py): support 'add_generation_prompt' field for hf chat templates
Fixes https://github.com/BerriAI/litellm/pull/5178#issuecomment-2306362008
This commit is contained in:
parent afb00a27cb
commit 874d58fe8a
3 changed files with 84 additions and 7 deletions
```diff
@@ -179,6 +179,9 @@ class HuggingfaceConfig:
             optional_params["decoder_input_details"] = True
         return optional_params
 
+    def get_hf_api_key(self) -> Optional[str]:
+        return litellm.utils.get_secret("HUGGINGFACE_API_KEY")
+
 
 def output_parser(generated_text: str):
     """
```
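The new `HuggingfaceConfig.get_hf_api_key` helper simply routes through litellm's secret lookup. As a rough illustration of what that resolves to in the simplest case, here is a hypothetical standalone version, assuming the key is only set as an environment variable:

```python
import os
from typing import Optional


def get_hf_api_key() -> Optional[str]:
    # Hypothetical simplification: litellm.utils.get_secret can also consult
    # configured secret managers; here we only read the environment variable.
    return os.environ.get("HUGGINGFACE_API_KEY")
```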
```diff
@@ -14,6 +14,7 @@ import litellm
 import litellm.types
 import litellm.types.llms
 import litellm.types.llms.vertex_ai
+from litellm import verbose_logger
 from litellm.llms.custom_httpx.http_handler import HTTPHandler
 from litellm.types.completion import (
     ChatCompletionFunctionMessageParam,
```
```diff
@@ -380,12 +381,14 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
     if chat_template is None:
 
         def _get_tokenizer_config(hf_model_name):
-            url = (
-                f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-            )
-            # Make a GET request to fetch the JSON data
-            client = HTTPHandler(concurrent_limit=1)
+            try:
+                url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
+
+                # Make a GET request to fetch the JSON data
+                client = HTTPHandler(concurrent_limit=1)
 
-            response = client.get(url)
+                response = client.get(url)
+            except Exception as e:
+                raise e
             if response.status_code == 200:
                 # Parse the JSON data
                 tokenizer_config = json.loads(response.content)
```
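The helper reads the model's `tokenizer_config.json` straight from the Hugging Face Hub. The same file can be fetched standalone; a minimal sketch using httpx rather than litellm's HTTPHandler (the function name here is mine):

```python
import json

import httpx


def fetch_tokenizer_config(hf_model_name: str) -> dict:
    # Same raw-file URL pattern as in the diff above.
    url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
    response = httpx.get(url)
    response.raise_for_status()
    return json.loads(response.content)


# e.g. fetch_tokenizer_config("mistralai/Mistral-7B-Instruct-v0.1").get("chat_template")
```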
```diff
@@ -397,6 +400,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
             tokenizer_config = known_tokenizer_config[model]
         else:
             tokenizer_config = _get_tokenizer_config(model)
+
         if (
             tokenizer_config["status"] == "failure"
             or "chat_template" not in tokenizer_config["tokenizer"]
```
```diff
@@ -406,7 +410,13 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
         tokenizer_config = tokenizer_config["tokenizer"]  # type: ignore
 
         bos_token = tokenizer_config["bos_token"]  # type: ignore
+        if bos_token is not None and not isinstance(bos_token, str):
+            if isinstance(bos_token, dict):
+                bos_token = bos_token.get("content", None)
         eos_token = tokenizer_config["eos_token"]  # type: ignore
+        if eos_token is not None and not isinstance(eos_token, str):
+            if isinstance(eos_token, dict):
+                eos_token = eos_token.get("content", None)
         chat_template = tokenizer_config["chat_template"]  # type: ignore
     try:
         template = env.from_string(chat_template)  # type: ignore
```
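Context for these guards: `tokenizer_config.json` sometimes stores `bos_token`/`eos_token` as plain strings and sometimes as AddedToken-style dicts (e.g. `{"content": "<s>", "lstrip": false, ...}`); the new checks reduce the dict form to its content string before it is handed to the Jinja template. A minimal sketch of the same normalization in isolation (the helper name is mine):

```python
def _normalize_special_token(token):
    # tokenizer_config.json may store "<s>" or {"content": "<s>", "lstrip": False, ...};
    # both shapes are reduced to the plain string (or None if absent).
    if isinstance(token, dict):
        return token.get("content", None)
    return token


assert _normalize_special_token("<s>") == "<s>"
assert _normalize_special_token({"content": "<s>", "lstrip": False}) == "<s>"
```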
```diff
@@ -431,7 +441,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
         # Render the template with the provided values
         if _is_system_in_template():
             rendered_text = template.render(
-                bos_token=bos_token, eos_token=eos_token, messages=messages
+                bos_token=bos_token,
+                eos_token=eos_token,
+                messages=messages,
+                add_generation_prompt=True,
             )
         else:
             # treat a system message as a user message, if system not in template
```
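This (together with the matching change in the next hunk) is the core of the fix: most Hub chat templates only append the trailing assistant header when `add_generation_prompt` is true, so rendering without it produces a prompt that stops at the user turn. A small illustration with a made-up template (jinja2 assumed available; the template string is not from any real model):

```python
from jinja2 import Template

# Hypothetical chat template in the style found in many tokenizer_config.json files.
chat_template = (
    "{% for m in messages %}<|{{ m['role'] }}|>{{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}<|assistant|>{% endif %}"
)
messages = [{"role": "user", "content": "hi"}]

print(Template(chat_template).render(messages=messages))
# <|user|>hi
print(Template(chat_template).render(messages=messages, add_generation_prompt=True))
# <|user|>hi
# <|assistant|>
```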
```diff
@@ -448,6 +461,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
                     bos_token=bos_token,
                     eos_token=eos_token,
                     messages=reformatted_messages,
+                    add_generation_prompt=True,
                 )
             except Exception as e:
                 if "Conversation roles must alternate user/assistant" in str(e):
```
```diff
@@ -469,8 +483,12 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
                     rendered_text = template.render(
                         bos_token=bos_token, eos_token=eos_token, messages=new_messages
                     )
 
         return rendered_text
     except Exception as e:
+        verbose_logger.exception(
+            "Error rendering huggingface chat template - {}".format(str(e))
+        )
         raise Exception(f"Error rendering template - {str(e)}")
 
+
```
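The added `verbose_logger.exception` call records the original traceback before the error is re-wrapped and re-raised; the raised message itself is unchanged. A minimal sketch of the same log-then-reraise pattern using the standard library logger (`verbose_logger` is litellm's own logger, so this is only an analogy):

```python
import logging

logger = logging.getLogger("litellm")


def render_or_raise(render_fn):
    try:
        return render_fn()
    except Exception as e:
        # logging's exception() attaches the active traceback to the log record.
        logger.exception("Error rendering huggingface chat template - {}".format(str(e)))
        raise Exception(f"Error rendering template - {str(e)}")
```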
```diff
@@ -253,6 +253,62 @@ async def test_completion_sagemaker_non_stream():
     )
 
 
+@pytest.mark.asyncio
+async def test_completion_sagemaker_prompt_template_non_stream():
+    mock_response = MagicMock()
+
+    def return_val():
+        return {
+            "generated_text": "This is a mock response from SageMaker.",
+            "id": "cmpl-mockid",
+            "object": "text_completion",
+            "created": 1629800000,
+            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            "choices": [
+                {
+                    "text": "This is a mock response from SageMaker.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "length",
+                }
+            ],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "inputs": "<|begin▁of▁sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n\n### Instruction:\nhi\n\n\n### Response:\n",
+        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.acompletion function
+        response = litellm.completion(
+            model="sagemaker/deepseek_coder_6.7_instruct",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            hf_model_name="mistralai/Mistral-7B-Instruct-v0.1",
+        )
+
+        # Print what was called on the mock
+        print("call args=", mock_post.call_args)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_sagemaker = kwargs["json"]
+        print("Arguments passed to sagemaker=", args_to_sagemaker)
+        assert args_to_sagemaker == expected_payload
+
+
 @pytest.mark.asyncio
 async def test_completion_sagemaker_non_stream_with_aws_params():
     mock_response = MagicMock()
```
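The new test pins the exact prompt string sent to the SageMaker endpoint, with the HTTP layer mocked. Outside the mock, the same knob applies to real calls: passing `hf_model_name` tells litellm to format the SageMaker prompt with that model's Hugging Face chat template, and with this fix the generation prompt is appended as well. A hedged usage sketch — the endpoint and model ids are placeholders, and valid AWS credentials are assumed:

```python
import litellm

# Placeholder endpoint name; requires a deployed SageMaker endpoint and AWS credentials.
response = litellm.completion(
    model="sagemaker/my-deepseek-coder-endpoint",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.2,
    max_tokens=80,
    hf_model_name="deepseek-ai/deepseek-coder-6.7b-instruct",  # HF model whose chat template formats the prompt
)
print(response.choices[0].message.content)
```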