Merge branch 'main' into feat/add-dana-agent-provider-stub

This commit is contained in:
Zooey Nguyen 2025-11-10 16:36:05 -08:00 committed by GitHub
commit 3f85df3da2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
62 changed files with 3463 additions and 3817 deletions

View file

@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
StopReason,
SystemMessage,
SystemMessageBehavior,
ToolCall,
ToolConfig,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
chat_completion_request_to_prompt,
interleaved_content_as_str,
)
MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
)
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
messages=[
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(tool_name=BuiltinTool.brave_search),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 3
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
assert messages[-1].content == content
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
UserMessage(content="hello"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
tool_name="custom1",
arguments='{"param1": "value1"}', # arguments must be a JSON string
call_id="123",
)
],
),
],
tools=[
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '[custom1(param1="value1")]' in prompt
request.model = MODEL
request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
prompt = await chat_completion_request_to_prompt(request, request.model)
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
)
messages = chat_completion_request_to_messages(request, MODEL)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert messages[-1].content == content
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
model=MODEL,
messages=[
SystemMessage(content=system_prompt),
UserMessage(content=content),
],
tools=[
ToolDefinition(tool_name=BuiltinTool.code_interpreter),
ToolDefinition(
tool_name="custom1",
description="custom1 tool",
input_schema={
"type": "object",
"properties": {
"param1": {
"type": "str",
"description": "param1 description",
},
},
"required": ["param1"],
},
),
],
tool_config=ToolConfig(
tool_choice="auto",
tool_prompt_format=ToolPromptFormat.python_list,
system_message_behavior=SystemMessageBehavior.replace,
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
assert len(messages) == 2
assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
# function description is present in the system prompt
assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
assert messages[-1].content == content

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
import pytest
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
ModelRunner,
)
class TestModelRunner:
"""Test ModelRunner task dispatching for model-parallel inference."""
def test_chat_completion_task_dispatch(self):
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
# Create a mock generator
mock_generator = Mock()
mock_generator.chat_completion = Mock(return_value=iter([]))
runner = ModelRunner(mock_generator)
# Create a chat_completion task
fake_params = {"model": "test"}
fake_messages = [{"role": "user", "content": "test"}]
task = ("chat_completion", [fake_params, fake_messages])
# Execute task
runner(task)
# Verify chat_completion was called with correct arguments
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
def test_invalid_task_type_raises_error(self):
"""Verify ModelRunner rejects invalid task types."""
mock_generator = Mock()
runner = ModelRunner(mock_generator)
with pytest.raises(ValueError, match="Unexpected task type"):
runner(("invalid_task", []))

View file

@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.inference import CompletionMessage, UserMessage
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
# Mock Guardrails API response
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
# Run the shield
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]
@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
adapter.shield_store.get_shield.return_value = None
messages = [
UserMessage(role="user", content="Hello, how are you?"),
OpenAIUserMessageParam(content="Hello, how are you?"),
]
with pytest.raises(ValueError):
@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
# Running the shield should raise an exception
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
OpenAIUserMessageParam(content="Hello, how are you?"),
OpenAIAssistantMessageParam(
content="I'm doing well, thank you for asking!",
stop_reason=StopReason.end_of_message,
tool_calls=[],
),
]

View file

@ -1,220 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from pydantic import ValidationError
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
CompletionMessage,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SystemMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_message_to_openai_dict_new,
openai_messages_to_messages,
)
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
"role": "user",
"content": [{"type": "text", "text": "Hello, world!"}],
}
# Test convert_message_to_openai_dict with a tool call
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
],
}
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
message = CompletionMessage(
content="",
tool_calls=[
ToolCall(
call_id="123",
tool_name=BuiltinTool.brave_search,
arguments='{"foo": "bar"}',
)
],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
],
}
async def test_openai_messages_to_messages_with_content_str():
openai_messages = [
OpenAISystemMessageParam(content="system message"),
OpenAIUserMessageParam(content="user message"),
OpenAIAssistantMessageParam(content="assistant message"),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content == "system message"
assert llama_messages[1].content == "user message"
assert llama_messages[2].content == "assistant message"
async def test_openai_messages_to_messages_with_content_list():
openai_messages = [
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
]
llama_messages = openai_messages_to_messages(openai_messages)
assert len(llama_messages) == 3
assert isinstance(llama_messages[0], SystemMessage)
assert isinstance(llama_messages[1], UserMessage)
assert isinstance(llama_messages[2], CompletionMessage)
assert llama_messages[0].content[0].text == "system message"
assert llama_messages[1].content[0].text == "user message"
assert llama_messages[2].content[0].text == "assistant message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_string(message_class, kwargs):
"""Test that messages accept string text content."""
msg = message_class(content="Test message", **kwargs)
assert msg.content == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIUserMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_accepts_text_list(message_class, kwargs):
"""Test that messages accept list of text content parts."""
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
msg = message_class(content=content_list, **kwargs)
assert len(msg.content) == 1
assert msg.content[0].text == "Test message"
@pytest.mark.parametrize(
"message_class,kwargs",
[
(OpenAISystemMessageParam, {}),
(OpenAIAssistantMessageParam, {}),
(OpenAIDeveloperMessageParam, {}),
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
],
)
def test_message_rejects_images(message_class, kwargs):
"""Test that system, assistant, developer, and tool messages reject image content."""
with pytest.raises(ValidationError):
message_class(
content=[
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
],
**kwargs,
)
def test_user_message_accepts_images():
"""Test that user messages accept image content (unlike other message types)."""
# List with images should work
msg = OpenAIUserMessageParam(
content=[
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
]
)
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
async def test_convert_message_to_openai_dict_new_user_message():
"""Test convert_message_to_openai_dict_new with UserMessage."""
message = UserMessage(content="Hello, world!", role="user")
result = await convert_message_to_openai_dict_new(message)
assert result["role"] == "user"
assert result["content"] == "Hello, world!"
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
message = CompletionMessage(
content="I'll help you find the weather.",
tool_calls=[
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,
)
result = await convert_message_to_openai_dict_new(message)
# This would have failed with "Cannot instantiate typing.Union" before the fix
assert result["role"] == "assistant"
assert result["content"] == "I'll help you find the weather."
assert "tool_calls" in result
assert result["tool_calls"] is not None
assert len(result["tool_calls"]) == 1
tool_call = result["tool_calls"][0]
assert tool_call.id == "call_123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == '{"city": "Sligo"}'

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.models.llama.datatypes import RawTextItem
from llama_stack.providers.utils.inference.prompt_adapter import (
convert_openai_message_to_raw_message,
)
class TestConvertOpenAIMessageToRawMessage:
"""Test conversion of OpenAI message types to RawMessage format."""
async def test_user_message_conversion(self):
msg = OpenAIUserMessageParam(role="user", content="Hello world")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "user"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hello world"
async def test_assistant_message_conversion(self):
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
raw_msg = await convert_openai_message_to_raw_message(msg)
assert raw_msg.role == "assistant"
assert isinstance(raw_msg.content, RawTextItem)
assert raw_msg.content.text == "Hi there!"
assert raw_msg.tool_calls == []