Merge branch 'main' into test-modelregistryhelper

This commit is contained in:
Matthew Farrellee 2025-04-27 10:56:30 -04:00
commit 7fd8a61b4d
80 changed files with 2918 additions and 386 deletions

View file

@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
else:
content = [await _convert_content(message.content)]
return {
result = {
"role": message.role,
"content": content,
}
if hasattr(message, "tool_calls") and message.tool_calls:
result["tool_calls"] = []
for tc in message.tool_calls:
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tc.tool_name,
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
},
}
)
return result
class UnparseableToolCall(BaseModel):
"""

View file

@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_2(request)
# llama3.2, llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
return messages
def augment_messages_for_tools_llama_3_2(
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> List[Message]:
existing_messages = request.messages
existing_system_message = None
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"