new prompt

Eric Huang 2025-04-24 13:16:42 -07:00
parent 7ed137e963
commit 8e9217774a
2 changed files with 226 additions and 27 deletions


@@ -12,7 +12,6 @@ import re
from typing import List, Optional, Tuple, Union
import httpx
from PIL import Image as PIL_Image
from llama_stack.apis.common.content_types import (
ImageContentItem,
@@ -52,9 +51,13 @@ from llama_stack.models.llama.llama3.prompt_templates import (
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.models.llama.sku_types import is_multimodal, ModelFamily
from llama_stack.providers.utils.inference import supported_inference_models
from PIL import Image as PIL_Image
log = get_logger(name=__name__, category="inference")
@@ -128,7 +131,9 @@ async def interleaved_content_convert_to_raw(
if image.url.uri.startswith("data"):
match = re.match(r"data:image/(\w+);base64,(.+)", image.url.uri)
if not match:
raise ValueError(f"Invalid data URL format, {image.url.uri[:40]}...")
raise ValueError(
f"Invalid data URL format, {image.url.uri[:40]}..."
)
_, image_data = match.groups()
data = base64.b64decode(image_data)
elif image.url.uri.startswith("file://"):
@@ -208,13 +213,17 @@ async def convert_image_content_to_url(
content, format = await localize_image_content(media)
if include_format:
return f"data:image/{format};base64," + base64.b64encode(content).decode("utf-8")
return f"data:image/{format};base64," + base64.b64encode(content).decode(
"utf-8"
)
else:
return base64.b64encode(content).decode("utf-8")
async def completion_request_to_prompt(request: CompletionRequest) -> str:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
request.content = content
request = await convert_request_to_raw(request)
@@ -226,7 +235,9 @@ async def completion_request_to_prompt(request: CompletionRequest) -> str:
async def completion_request_to_prompt_model_input_info(
request: CompletionRequest,
) -> Tuple[str, int]:
content = augment_content_with_response_format_prompt(request.response_format, request.content)
content = augment_content_with_response_format_prompt(
request.response_format, request.content
)
request.content = content
request = await convert_request_to_raw(request)
@@ -247,7 +258,9 @@ def augment_content_with_response_format_prompt(response_format, content):
return content
async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llama_model: str) -> str:
async def chat_completion_request_to_prompt(
request: ChatCompletionRequest, llama_model: str
) -> str:
messages = chat_completion_request_to_messages(request, llama_model)
request.messages = messages
request = await convert_request_to_raw(request)
@@ -255,7 +268,8 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
tool_prompt_format=request.tool_config.tool_prompt_format
or get_default_tool_prompt_format(llama_model),
)
return formatter.tokenizer.decode(model_input.tokens)
@@ -270,7 +284,8 @@ async def chat_completion_request_to_model_input_info(
formatter = ChatFormat(tokenizer=Tokenizer.get_instance())
model_input = formatter.encode_dialog_prompt(
request.messages,
tool_prompt_format=request.tool_config.tool_prompt_format or get_default_tool_prompt_format(llama_model),
tool_prompt_format=request.tool_config.tool_prompt_format
or get_default_tool_prompt_format(llama_model),
)
return (
formatter.tokenizer.decode(model_input.tokens),
@@ -299,17 +314,23 @@ def chat_completion_request_to_messages(
return request.messages
if model.model_family == ModelFamily.llama3_1 or (
model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
model.model_family == ModelFamily.llama3_2
and is_multimodal(model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_1(request)
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_2(request)
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_2(
request, PythonListCustomToolGenerator
)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama_3_2(
request, PythonListCustomToolGeneratorLlama4
)
else:
messages = request.messages
@@ -339,7 +360,9 @@ def augment_messages_for_tools_llama_3_1(
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
messages = []
@@ -371,9 +394,13 @@ def augment_messages_for_tools_llama_3_1(
if isinstance(existing_system_message.content, str):
sys_content += _process(existing_system_message.content)
elif isinstance(existing_system_message.content, list):
sys_content += "\n".join([_process(c) for c in existing_system_message.content])
sys_content += "\n".join(
[_process(c) for c in existing_system_message.content]
)
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
tool_choice_prompt = _get_tool_choice_prompt(
request.tool_config.tool_choice, request.tools
)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
@@ -401,13 +428,16 @@ def augment_messages_for_tools_llama_3_1(
def augment_messages_for_tools_llama_3_2(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> List[Message]:
existing_messages = request.messages
existing_system_message = None
if existing_messages[0].role == Role.system.value:
existing_system_message = existing_messages.pop(0)
assert existing_messages[0].role != Role.system.value, "Should only have 1 system message"
assert (
existing_messages[0].role != Role.system.value
), "Should only have 1 system message"
sys_content = ""
custom_tools, builtin_tools = [], []
@@ -428,23 +458,34 @@ def augment_messages_for_tools_llama_3_2(
if custom_tools:
fmt = request.tool_config.tool_prompt_format or ToolPromptFormat.python_list
if fmt != ToolPromptFormat.python_list:
raise ValueError(f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}")
raise ValueError(
f"Non supported ToolPromptFormat {request.tool_config.tool_prompt_format}"
)
system_prompt = None
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
if (
existing_system_message
and request.tool_config.system_message_behavior
== SystemMessageBehavior.replace
):
system_prompt = existing_system_message.content
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"
if existing_system_message and (
request.tool_config.system_message_behavior == SystemMessageBehavior.append or not custom_tools
request.tool_config.system_message_behavior == SystemMessageBehavior.append
or not custom_tools
):
sys_content += interleaved_content_as_str(existing_system_message.content, sep="\n")
sys_content += interleaved_content_as_str(
existing_system_message.content, sep="\n"
)
tool_choice_prompt = _get_tool_choice_prompt(request.tool_config.tool_choice, request.tools)
tool_choice_prompt = _get_tool_choice_prompt(
request.tool_config.tool_choice, request.tools
)
if tool_choice_prompt:
sys_content += "\n" + tool_choice_prompt
@@ -452,11 +493,15 @@ def augment_messages_for_tools_llama_3_2(
return messages
def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefinition]) -> str:
def _get_tool_choice_prompt(
tool_choice: ToolChoice | str, tools: List[ToolDefinition]
) -> str:
if tool_choice == ToolChoice.auto:
return ""
elif tool_choice == ToolChoice.required:
return "You MUST use one of the provided functions/tools to answer the user query."
return (
"You MUST use one of the provided functions/tools to answer the user query."
)
elif tool_choice == ToolChoice.none:
# tools are already not passed in
return ""
@@ -468,11 +513,14 @@ def _get_tool_choice_prompt(tool_choice: ToolChoice | str, tools: List[ToolDefin
def get_default_tool_prompt_format(model: str) -> ToolPromptFormat:
llama_model = resolve_model(model)
if llama_model is None:
log.warning(f"Could not resolve model {model}, defaulting to json tool prompt format")
log.warning(
f"Could not resolve model {model}, defaulting to json tool prompt format"
)
return ToolPromptFormat.json
if llama_model.model_family == ModelFamily.llama3_1 or (
llama_model.model_family == ModelFamily.llama3_2 and is_multimodal(llama_model.core_model_id)
llama_model.model_family == ModelFamily.llama3_2
and is_multimodal(llama_model.core_model_id)
):
# llama3.1 and llama3.2 multimodal models follow the same tool prompt format
return ToolPromptFormat.json