mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-05 02:17:31 +00:00
Merge branch 'main' into feat/add-dana-agent-provider-stub
This commit is contained in:
commit
3f85df3da2
62 changed files with 3463 additions and 3817 deletions
5
tests/unit/providers/inline/inference/__init__.py
Normal file
5
tests/unit/providers/inline/inference/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
44
tests/unit/providers/inline/inference/test_meta_reference.py
Normal file
44
tests/unit/providers/inline/inference/test_meta_reference.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
|
||||
ModelRunner,
|
||||
)
|
||||
|
||||
|
||||
class TestModelRunner:
|
||||
"""Test ModelRunner task dispatching for model-parallel inference."""
|
||||
|
||||
def test_chat_completion_task_dispatch(self):
|
||||
"""Verify ModelRunner correctly dispatches chat_completion tasks."""
|
||||
# Create a mock generator
|
||||
mock_generator = Mock()
|
||||
mock_generator.chat_completion = Mock(return_value=iter([]))
|
||||
|
||||
runner = ModelRunner(mock_generator)
|
||||
|
||||
# Create a chat_completion task
|
||||
fake_params = {"model": "test"}
|
||||
fake_messages = [{"role": "user", "content": "test"}]
|
||||
task = ("chat_completion", [fake_params, fake_messages])
|
||||
|
||||
# Execute task
|
||||
runner(task)
|
||||
|
||||
# Verify chat_completion was called with correct arguments
|
||||
mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)
|
||||
|
||||
def test_invalid_task_type_raises_error(self):
|
||||
"""Verify ModelRunner rejects invalid task types."""
|
||||
mock_generator = Mock()
|
||||
runner = ModelRunner(mock_generator)
|
||||
|
||||
with pytest.raises(ValueError, match="Unexpected task type"):
|
||||
runner(("invalid_task", []))
|
||||
|
|
@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
|
@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
|
|||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
OpenAIAssistantMessageParam(
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
|
|
@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
|
|||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
OpenAIAssistantMessageParam(
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
|
|
@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
|
|||
adapter.shield_store.get_shield.return_value = None
|
||||
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
|
|
@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
|
|||
|
||||
# Running the shield should raise an exception
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
OpenAIUserMessageParam(content="Hello, how are you?"),
|
||||
OpenAIAssistantMessageParam(
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,220 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from llama_stack.apis.common.content_types import TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
convert_message_to_openai_dict_new,
|
||||
openai_messages_to_messages,
|
||||
)
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict():
|
||||
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
|
||||
assert await convert_message_to_openai_dict(message) == {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello, world!"}],
|
||||
}
|
||||
|
||||
|
||||
# Test convert_message_to_openai_dict with a tool call
|
||||
async def test_convert_message_to_openai_dict_with_tool_call():
|
||||
message = CompletionMessage(
|
||||
content="",
|
||||
tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
|
||||
openai_dict = await convert_message_to_openai_dict(message)
|
||||
|
||||
assert openai_dict == {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": ""}],
|
||||
"tool_calls": [
|
||||
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_with_builtin_tool_call():
|
||||
message = CompletionMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="123",
|
||||
tool_name=BuiltinTool.brave_search,
|
||||
arguments='{"foo": "bar"}',
|
||||
)
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
|
||||
openai_dict = await convert_message_to_openai_dict(message)
|
||||
|
||||
assert openai_dict == {
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": ""}],
|
||||
"tool_calls": [
|
||||
{"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def test_openai_messages_to_messages_with_content_str():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content="system message"),
|
||||
OpenAIUserMessageParam(content="user message"),
|
||||
OpenAIAssistantMessageParam(content="assistant message"),
|
||||
]
|
||||
|
||||
llama_messages = openai_messages_to_messages(openai_messages)
|
||||
assert len(llama_messages) == 3
|
||||
assert isinstance(llama_messages[0], SystemMessage)
|
||||
assert isinstance(llama_messages[1], UserMessage)
|
||||
assert isinstance(llama_messages[2], CompletionMessage)
|
||||
assert llama_messages[0].content == "system message"
|
||||
assert llama_messages[1].content == "user message"
|
||||
assert llama_messages[2].content == "assistant message"
|
||||
|
||||
|
||||
async def test_openai_messages_to_messages_with_content_list():
|
||||
openai_messages = [
|
||||
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
|
||||
OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
|
||||
OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
|
||||
]
|
||||
|
||||
llama_messages = openai_messages_to_messages(openai_messages)
|
||||
assert len(llama_messages) == 3
|
||||
assert isinstance(llama_messages[0], SystemMessage)
|
||||
assert isinstance(llama_messages[1], UserMessage)
|
||||
assert isinstance(llama_messages[2], CompletionMessage)
|
||||
assert llama_messages[0].content[0].text == "system message"
|
||||
assert llama_messages[1].content[0].text == "user message"
|
||||
assert llama_messages[2].content[0].text == "assistant message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIUserMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_accepts_text_string(message_class, kwargs):
|
||||
"""Test that messages accept string text content."""
|
||||
msg = message_class(content="Test message", **kwargs)
|
||||
assert msg.content == "Test message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIUserMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_accepts_text_list(message_class, kwargs):
|
||||
"""Test that messages accept list of text content parts."""
|
||||
content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
|
||||
msg = message_class(content=content_list, **kwargs)
|
||||
assert len(msg.content) == 1
|
||||
assert msg.content[0].text == "Test message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"message_class,kwargs",
|
||||
[
|
||||
(OpenAISystemMessageParam, {}),
|
||||
(OpenAIAssistantMessageParam, {}),
|
||||
(OpenAIDeveloperMessageParam, {}),
|
||||
(OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
|
||||
],
|
||||
)
|
||||
def test_message_rejects_images(message_class, kwargs):
|
||||
"""Test that system, assistant, developer, and tool messages reject image content."""
|
||||
with pytest.raises(ValidationError):
|
||||
message_class(
|
||||
content=[
|
||||
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
|
||||
],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_user_message_accepts_images():
|
||||
"""Test that user messages accept image content (unlike other message types)."""
|
||||
# List with images should work
|
||||
msg = OpenAIUserMessageParam(
|
||||
content=[
|
||||
OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
|
||||
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
|
||||
]
|
||||
)
|
||||
assert len(msg.content) == 2
|
||||
assert msg.content[0].text == "Describe this image:"
|
||||
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_user_message():
|
||||
"""Test convert_message_to_openai_dict_new with UserMessage."""
|
||||
message = UserMessage(content="Hello, world!", role="user")
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
assert result["role"] == "user"
|
||||
assert result["content"] == "Hello, world!"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
|
||||
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
|
||||
message = CompletionMessage(
|
||||
content="I'll help you find the weather.",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_123",
|
||||
tool_name="get_weather",
|
||||
arguments='{"city": "Sligo"}',
|
||||
)
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
# This would have failed with "Cannot instantiate typing.Union" before the fix
|
||||
assert result["role"] == "assistant"
|
||||
assert result["content"] == "I'll help you find the weather."
|
||||
assert "tool_calls" in result
|
||||
assert result["tool_calls"] is not None
|
||||
assert len(result["tool_calls"]) == 1
|
||||
|
||||
tool_call = result["tool_calls"][0]
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_weather"
|
||||
assert tool_call.function.arguments == '{"city": "Sligo"}'
|
||||
35
tests/unit/providers/utils/inference/test_prompt_adapter.py
Normal file
35
tests/unit/providers/utils/inference/test_prompt_adapter.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import RawTextItem
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_openai_message_to_raw_message,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertOpenAIMessageToRawMessage:
|
||||
"""Test conversion of OpenAI message types to RawMessage format."""
|
||||
|
||||
async def test_user_message_conversion(self):
|
||||
msg = OpenAIUserMessageParam(role="user", content="Hello world")
|
||||
raw_msg = await convert_openai_message_to_raw_message(msg)
|
||||
|
||||
assert raw_msg.role == "user"
|
||||
assert isinstance(raw_msg.content, RawTextItem)
|
||||
assert raw_msg.content.text == "Hello world"
|
||||
|
||||
async def test_assistant_message_conversion(self):
|
||||
msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
|
||||
raw_msg = await convert_openai_message_to_raw_message(msg)
|
||||
|
||||
assert raw_msg.role == "assistant"
|
||||
assert isinstance(raw_msg.content, RawTextItem)
|
||||
assert raw_msg.content.text == "Hi there!"
|
||||
assert raw_msg.tool_calls == []
|
||||
Loading…
Add table
Add a link
Reference in a new issue