Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-04 18:13:44 +00:00)

Commit 08024d44f2: Merge remote-tracking branch 'origin/main' into storage_fix
89 changed files with 4786 additions and 3941 deletions
@@ -516,3 +516,169 @@ def test_response_with_instructions(openai_client, client_with_models, text_model_id):
    # Verify instructions from the previous response were not carried over to the next response
    assert response_with_instructions2.instructions == instructions2


@pytest.mark.skip(reason="Tool calling is not reliable.")
def test_max_tool_calls_with_function_tools(openai_client, client_with_models, text_model_id):
    """Test handling of max_tool_calls with function tools in responses."""
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI responses are not supported when testing with library client yet.")

    client = openai_client
    max_tool_calls = 1

    tools = [
        {
            "type": "function",
            "name": "get_weather",
            "description": "Get weather information for a specified location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name (e.g., 'New York', 'London')",
                    },
                },
            },
        },
        {
            "type": "function",
            "name": "get_time",
            "description": "Get current time for a specified location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city name (e.g., 'New York', 'London')",
                    },
                },
            },
        },
    ]

    # First create a response that triggers function tools
    response = client.responses.create(
        model=text_model_id,
        input="Can you tell me the weather in Paris and the current time?",
        tools=tools,
        stream=False,
        max_tool_calls=max_tool_calls,
    )

    # Verify we got two function calls and that max_tool_calls does not affect function tools
    assert len(response.output) == 2
    assert response.output[0].type == "function_call"
    assert response.output[0].name == "get_weather"
    assert response.output[0].status == "completed"
    assert response.output[1].type == "function_call"
    assert response.output[1].name == "get_time"
    assert response.output[1].status == "completed"

    # Verify we have a valid max_tool_calls field
    assert response.max_tool_calls == max_tool_calls


def test_max_tool_calls_invalid(openai_client, client_with_models, text_model_id):
    """Test handling of invalid max_tool_calls in responses."""
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI responses are not supported when testing with library client yet.")

    client = openai_client

    input = "Search for today's top technology news."
    invalid_max_tool_calls = 0
    tools = [
        {"type": "web_search"},
    ]

    # Create a response with an invalid max_tool_calls value, i.e. 0
    # Handle ValueError from LLS and BadRequestError from the OpenAI client
    with pytest.raises((ValueError, BadRequestError)) as excinfo:
        client.responses.create(
            model=text_model_id,
            input=input,
            tools=tools,
            stream=False,
            max_tool_calls=invalid_max_tool_calls,
        )

    error_message = str(excinfo.value)
    assert f"Invalid max_tool_calls={invalid_max_tool_calls}; should be >= 1" in error_message, (
        f"Expected error message about invalid max_tool_calls, got: {error_message}"
    )


def test_max_tool_calls_with_builtin_tools(openai_client, client_with_models, text_model_id):
    """Test handling of max_tool_calls with built-in tools in responses."""
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI responses are not supported when testing with library client yet.")

    client = openai_client

    input = "Search for today's top technology and a positive news story. You MUST make exactly two separate web search calls."
    max_tool_calls = [1, 5]
    tools = [
        {"type": "web_search"},
    ]

    # First create a response that triggers web_search tools without max_tool_calls
    response = client.responses.create(
        model=text_model_id,
        input=input,
        tools=tools,
        stream=False,
    )

    # Verify we got two web search calls followed by a message
    assert len(response.output) == 3
    assert response.output[0].type == "web_search_call"
    assert response.output[0].status == "completed"
    assert response.output[1].type == "web_search_call"
    assert response.output[1].status == "completed"
    assert response.output[2].type == "message"
    assert response.output[2].status == "completed"
    assert response.output[2].role == "assistant"

    # Next create a response that triggers web_search tools with max_tool_calls set to 1
    response_2 = client.responses.create(
        model=text_model_id,
        input=input,
        tools=tools,
        stream=False,
        max_tool_calls=max_tool_calls[0],
    )

    # Verify we got one web search tool call followed by a message
    assert len(response_2.output) == 2
    assert response_2.output[0].type == "web_search_call"
    assert response_2.output[0].status == "completed"
    assert response_2.output[1].type == "message"
    assert response_2.output[1].status == "completed"
    assert response_2.output[1].role == "assistant"

    # Verify we have a valid max_tool_calls field
    assert response_2.max_tool_calls == max_tool_calls[0]

    # Finally create a response that triggers web_search tools with max_tool_calls set to 5
    response_3 = client.responses.create(
        model=text_model_id,
        input=input,
        tools=tools,
        stream=False,
        max_tool_calls=max_tool_calls[1],
    )

    # Verify we got two web search calls followed by a message
    assert len(response_3.output) == 3
    assert response_3.output[0].type == "web_search_call"
    assert response_3.output[0].status == "completed"
    assert response_3.output[1].type == "web_search_call"
    assert response_3.output[1].status == "completed"
    assert response_3.output[2].type == "message"
    assert response_3.output[2].status == "completed"
    assert response_3.output[2].role == "assistant"

    # Verify we have a valid max_tool_calls field
    assert response_3.max_tool_calls == max_tool_calls[1]
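Taken together, the three tests above pin down the max_tool_calls contract: it caps built-in tool invocations (such as web_search) per response, it does not limit function tools, values below 1 are rejected, and the value is echoed back on the response object. A compact usage sketch distilled from those tests, assuming the same client and text_model_id fixtures:

# Sketch only, reusing the fixtures above: max_tool_calls caps built-in tool
# calls per response, must be >= 1, and is echoed on the response object.
response = client.responses.create(
    model=text_model_id,
    input="Search for today's top technology news.",
    tools=[{"type": "web_search"}],
    stream=False,
    max_tool_calls=1,
)
assert response.max_tool_calls == 1
assert sum(o.type == "web_search_call" for o in response.output) <= 1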
@@ -334,7 +334,13 @@ def require_server(llama_stack_client):
@pytest.fixture(scope="session")
def openai_client(llama_stack_client, require_server):
    base_url = f"{llama_stack_client.base_url}/v1"
-   return OpenAI(base_url=base_url, api_key="fake")
+   client = OpenAI(base_url=base_url, api_key="fake", max_retries=0, timeout=30.0)
+   yield client
+   # Cleanup: close HTTP connections
+   try:
+       client.close()
+   except Exception:
+       pass


@pytest.fixture(params=["openai_client", "client_with_models"])
@@ -54,6 +54,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
    # {"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos,
    # or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}
    "remote::groq",
    "remote::oci",
    "remote::gemini",  # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
    "remote::anthropic",  # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported
    "remote::azure",  # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation
@@ -138,6 +138,7 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
    "remote::runpod",
    "remote::sambanova",
    "remote::tgi",
    "remote::oci",
):
    pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
tests/integration/recordings/README.md (generated, 4 changes)

@@ -2,6 +2,10 @@

This directory contains recorded inference API responses used for deterministic testing without requiring live API access.

For more information, see the
[docs](https://llamastack.github.io/docs/contributing/testing/record-replay).
This README provides more technical information.

## Structure

- `responses/` - JSON files containing request/response pairs for inference operations
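As context for the README hunk above: recordings are stored as JSON request/response pairs. A minimal sketch of what one such pair could look like, written as a Python dict; the field names here are illustrative assumptions, not the actual recording schema (the linked record-replay docs define the real format):

# Hypothetical shape of one entry under tests/integration/recordings/responses/
# (illustrative only; see the record-replay docs for the real schema)
recording = {
    "request": {
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {"model": "llama3.2:3b", "messages": [{"role": "user", "content": "Hello"}]},
    },
    "response": {
        "status_code": 200,
        "body": {"choices": [{"message": {"role": "assistant", "content": "Hi there!"}}]},
    },
}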
@@ -115,7 +115,15 @@ def openai_client(base_url, api_key, provider):
        client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
        return client

-   return OpenAI(
+   client = OpenAI(
        base_url=base_url,
        api_key=api_key,
+       max_retries=0,
+       timeout=30.0,
    )
+   yield client
+   # Cleanup: close HTTP connections
+   try:
+       client.close()
+   except Exception:
+       pass
@@ -65,8 +65,14 @@ class TestConversationResponses:
        conversation_items = openai_client.conversations.items.list(conversation.id)
        assert len(conversation_items.data) >= 4  # 2 user + 2 assistant messages

+   @pytest.mark.timeout(60, method="thread")
    def test_conversation_context_loading(self, openai_client, text_model_id):
-       """Test that conversation context is properly loaded for responses."""
+       """Test that conversation context is properly loaded for responses.
+
+       Note: 60s timeout added due to CI-specific deadlock in pytest/OpenAI client/httpx
+       after running 25+ tests. Hangs before first HTTP request is made. Works fine locally.
+       Investigation needed: connection pool exhaustion or event loop state issue.
+       """
        conversation = openai_client.conversations.create(
            items=[
                {"type": "message", "role": "user", "content": "My name is Alice. I like to eat apples."},
@@ -11,6 +11,7 @@ import pytest
from llama_stack_client import BadRequestError
from openai import BadRequestError as OpenAIBadRequestError

from llama_stack.apis.files import ExpiresAfter
from llama_stack.apis.vector_io import Chunk
from llama_stack.core.library_client import LlamaStackAsLibraryClient
from llama_stack.log import get_logger
@@ -907,16 +908,16 @@ def test_openai_vector_store_retrieve_file_contents(
    )

    assert file_contents is not None
-   assert len(file_contents.content) == 1
-   content = file_contents.content[0]
+   assert file_contents.object == "vector_store.file_content.page"
+   assert len(file_contents.data) == 1
+   content = file_contents.data[0]

    # llama-stack-client returns a model, openai-python is a badboy and returns a dict
    if not isinstance(content, dict):
        content = content.model_dump()
    assert content["type"] == "text"
    assert content["text"] == test_content.decode("utf-8")
    assert file_contents.filename == file_name
    assert file_contents.attributes == attributes
+   assert file_contents.has_more is False


@vector_provider_wrapper
@@ -1483,14 +1484,12 @@ def test_openai_vector_store_file_batch_retrieve_contents(
    )

    assert file_contents is not None
    assert file_contents.filename == file_data[i][0]
-   assert len(file_contents.content) > 0
+   assert file_contents.object == "vector_store.file_content.page"
+   assert len(file_contents.data) > 0

    # Verify the content matches what we uploaded
    content_text = (
-       file_contents.content[0].text
-       if hasattr(file_contents.content[0], "text")
-       else file_contents.content[0]["text"]
+       file_contents.data[0].text if hasattr(file_contents.data[0], "text") else file_contents.data[0]["text"]
    )
    assert file_data[i][1].decode("utf-8") in content_text
@@ -1606,3 +1605,97 @@ def test_openai_vector_store_embedding_config_from_metadata(

    assert "metadata_config_store" in store_names
    assert "consistent_config_store" in store_names


@vector_provider_wrapper
def test_openai_vector_store_file_contents_with_extra_query(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension, vector_io_provider_id
):
    """Test that vector store file contents endpoint supports extra_query parameter."""
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
    compat_client = compat_client_with_empty_stores

    # Create a vector store
    vector_store = compat_client.vector_stores.create(
        name="test_extra_query_store",
        extra_body={
            "embedding_model": embedding_model_id,
            "provider_id": vector_io_provider_id,
        },
    )

    # Create and attach a file
    test_content = b"This is test content for extra_query validation."
    with BytesIO(test_content) as file_buffer:
        file_buffer.name = "test_extra_query.txt"
        file = compat_client.files.create(
            file=file_buffer,
            purpose="assistants",
            expires_after=ExpiresAfter(anchor="created_at", seconds=86400),
        )

    file_attach_response = compat_client.vector_stores.files.create(
        vector_store_id=vector_store.id,
        file_id=file.id,
        extra_body={"embedding_model": embedding_model_id},
    )
    assert file_attach_response.status == "completed"

    # Wait for processing
    time.sleep(2)

    # Test that extra_query parameter is accepted and processed
    content_with_extra_query = compat_client.vector_stores.files.content(
        vector_store_id=vector_store.id,
        file_id=file.id,
        extra_query={"include_embeddings": True, "include_metadata": True},
    )

    # Test without extra_query for comparison
    content_without_extra_query = compat_client.vector_stores.files.content(
        vector_store_id=vector_store.id,
        file_id=file.id,
    )

    # Validate that both calls succeed
    assert content_with_extra_query is not None
    assert content_without_extra_query is not None
    assert len(content_with_extra_query.data) > 0
    assert len(content_without_extra_query.data) > 0

    # Validate that extra_query parameter is processed correctly
    # Both should have the embedding/metadata fields available (may be None based on flags)
    first_chunk_with_flags = content_with_extra_query.data[0]
    first_chunk_without_flags = content_without_extra_query.data[0]

    # The key validation: extra_query fields are present in the response
    # Handle both dict and object responses (different clients may return different formats)
    def has_field(obj, field):
        if isinstance(obj, dict):
            return field in obj
        else:
            return hasattr(obj, field)

    # Validate that all expected fields are present in both responses
    expected_fields = ["embedding", "chunk_metadata", "metadata", "text"]
    for field in expected_fields:
        assert has_field(first_chunk_with_flags, field), f"Field '{field}' missing from response with extra_query"
        assert has_field(first_chunk_without_flags, field), f"Field '{field}' missing from response without extra_query"

    # Validate content is the same
    def get_field(obj, field):
        if isinstance(obj, dict):
            return obj[field]
        else:
            return getattr(obj, field)

    assert get_field(first_chunk_with_flags, "text") == test_content.decode("utf-8")
    assert get_field(first_chunk_without_flags, "text") == test_content.decode("utf-8")

    with_flags_embedding = get_field(first_chunk_with_flags, "embedding")
    without_flags_embedding = get_field(first_chunk_without_flags, "embedding")

    # Validate that embeddings are included when requested and excluded when not requested
    assert with_flags_embedding is not None, "Embeddings should be included when include_embeddings=True"
    assert len(with_flags_embedding) > 0, "Embedding should be a non-empty list"
    assert without_flags_embedding is None, "Embeddings should not be included when include_embeddings=False"
@@ -55,3 +55,65 @@ async def test_create_vector_stores_multiple_providers_missing_provider_id_error

    with pytest.raises(ValueError, match="Multiple vector_io providers available"):
        await router.openai_create_vector_store(request)


async def test_update_vector_store_provider_id_change_fails():
    """Test that updating a vector store with a different provider_id fails with clear error."""
    mock_routing_table = Mock()

    # Mock an existing vector store with provider_id "faiss"
    mock_existing_store = Mock()
    mock_existing_store.provider_id = "inline::faiss"
    mock_existing_store.identifier = "vs_123"

    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
    mock_routing_table.get_provider_impl = AsyncMock(
        return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
    )

    router = VectorIORouter(mock_routing_table)

    # Try to update with different provider_id in metadata - this should fail
    with pytest.raises(ValueError, match="provider_id cannot be changed after vector store creation"):
        await router.openai_update_vector_store(
            vector_store_id="vs_123",
            name="updated_name",
            metadata={"provider_id": "inline::sqlite"},  # Different provider_id
        )

    # Verify the existing store was looked up to check provider_id
    mock_routing_table.get_object_by_identifier.assert_called_once_with("vector_store", "vs_123")

    # Provider should not be called since validation failed
    mock_routing_table.get_provider_impl.assert_not_called()


async def test_update_vector_store_same_provider_id_succeeds():
    """Test that updating a vector store with the same provider_id succeeds."""
    mock_routing_table = Mock()

    # Mock an existing vector store with provider_id "faiss"
    mock_existing_store = Mock()
    mock_existing_store.provider_id = "inline::faiss"
    mock_existing_store.identifier = "vs_123"

    mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
    mock_routing_table.get_provider_impl = AsyncMock(
        return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
    )

    router = VectorIORouter(mock_routing_table)

    # Update with same provider_id should succeed
    await router.openai_update_vector_store(
        vector_store_id="vs_123",
        name="updated_name",
        metadata={"provider_id": "inline::faiss"},  # Same provider_id
    )

    # Verify the provider update method was called
    mock_routing_table.get_provider_impl.assert_called_once_with("vs_123")
    provider = await mock_routing_table.get_provider_impl("vs_123")
    provider.openai_update_vector_store.assert_called_once_with(
        vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
    )
@@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack.apis.inference import (
    ChatCompletionRequest,
    CompletionMessage,
    StopReason,
    SystemMessage,
    SystemMessageBehavior,
    ToolCall,
    ToolConfig,
    UserMessage,
)
from llama_stack.models.llama.datatypes import (
    BuiltinTool,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
    chat_completion_request_to_prompt,
    interleaved_content_as_str,
)

MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"


async def test_system_default():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 2
    assert messages[-1].content == content
    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)


async def test_system_builtin_only():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(tool_name=BuiltinTool.brave_search),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 2
    assert messages[-1].content == content
    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)


async def test_system_custom_only():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                input_schema={
                    "type": "object",
                    "properties": {
                        "param1": {
                            "type": "str",
                            "description": "param1 description",
                        },
                    },
                    "required": ["param1"],
                },
            )
        ],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 3
    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)

    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
    assert messages[-1].content == content


async def test_system_custom_and_builtin():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(tool_name=BuiltinTool.brave_search),
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                input_schema={
                    "type": "object",
                    "properties": {
                        "param1": {
                            "type": "str",
                            "description": "param1 description",
                        },
                    },
                    "required": ["param1"],
                },
            ),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 3

    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)

    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
    assert messages[-1].content == content


async def test_completion_message_encoding():
    request = ChatCompletionRequest(
        model=MODEL3_2,
        messages=[
            UserMessage(content="hello"),
            CompletionMessage(
                content="",
                stop_reason=StopReason.end_of_turn,
                tool_calls=[
                    ToolCall(
                        tool_name="custom1",
                        arguments='{"param1": "value1"}',  # arguments must be a JSON string
                        call_id="123",
                    )
                ],
            ),
        ],
        tools=[
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                input_schema={
                    "type": "object",
                    "properties": {
                        "param1": {
                            "type": "str",
                            "description": "param1 description",
                        },
                    },
                    "required": ["param1"],
                },
            ),
        ],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
    )
    prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '[custom1(param1="value1")]' in prompt

    request.model = MODEL
    request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
    prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt


async def test_user_provided_system_message():
    content = "Hello !"
    system_prompt = "You are a pirate"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 2
    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)

    assert messages[-1].content == content


async def test_replace_system_message_behavior_builtin_tools():
    content = "Hello !"
    system_prompt = "You are a pirate"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format=ToolPromptFormat.python_list,
            system_message_behavior=SystemMessageBehavior.replace,
        ),
    )
    messages = chat_completion_request_to_messages(request, MODEL3_2)
    assert len(messages) == 2
    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
    assert messages[-1].content == content


async def test_replace_system_message_behavior_custom_tools():
    content = "Hello !"
    system_prompt = "You are a pirate"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                input_schema={
                    "type": "object",
                    "properties": {
                        "param1": {
                            "type": "str",
                            "description": "param1 description",
                        },
                    },
                    "required": ["param1"],
                },
            ),
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format=ToolPromptFormat.python_list,
            system_message_behavior=SystemMessageBehavior.replace,
        ),
    )
    messages = chat_completion_request_to_messages(request, MODEL3_2)

    assert len(messages) == 2
    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
    assert messages[-1].content == content


async def test_replace_system_message_behavior_custom_tools_with_template():
    content = "Hello !"
    system_prompt = "You are a pirate {{ function_description }}"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                input_schema={
                    "type": "object",
                    "properties": {
                        "param1": {
                            "type": "str",
                            "description": "param1 description",
                        },
                    },
                    "required": ["param1"],
                },
            ),
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format=ToolPromptFormat.python_list,
            system_message_behavior=SystemMessageBehavior.replace,
        ),
    )
    messages = chat_completion_request_to_messages(request, MODEL3_2)

    assert len(messages) == 2
    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
    assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
    # function description is present in the system prompt
    assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
    assert messages[-1].content == content
tests/unit/providers/inline/inference/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
tests/unit/providers/inline/inference/test_meta_reference.py (new file, 44 lines)

@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import Mock

import pytest

from llama_stack.providers.inline.inference.meta_reference.model_parallel import (
    ModelRunner,
)


class TestModelRunner:
    """Test ModelRunner task dispatching for model-parallel inference."""

    def test_chat_completion_task_dispatch(self):
        """Verify ModelRunner correctly dispatches chat_completion tasks."""
        # Create a mock generator
        mock_generator = Mock()
        mock_generator.chat_completion = Mock(return_value=iter([]))

        runner = ModelRunner(mock_generator)

        # Create a chat_completion task
        fake_params = {"model": "test"}
        fake_messages = [{"role": "user", "content": "test"}]
        task = ("chat_completion", [fake_params, fake_messages])

        # Execute task
        runner(task)

        # Verify chat_completion was called with correct arguments
        mock_generator.chat_completion.assert_called_once_with(fake_params, fake_messages)

    def test_invalid_task_type_raises_error(self):
        """Verify ModelRunner rejects invalid task types."""
        mock_generator = Mock()
        runner = ModelRunner(mock_generator)

        with pytest.raises(ValueError, match="Unexpected task type"):
            runner(("invalid_task", []))
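These tests pin down a small dispatch contract: the runner is called with a (task_type, args) tuple and routes it to the generator. A minimal, illustrative stand-in for that contract (the real ModelRunner lives in llama_stack.providers.inline.inference.meta_reference.model_parallel and may differ in detail):

class ModelRunnerSketch:
    """Sketch only: routes ("chat_completion", [params, messages]) tuples to a generator."""

    def __init__(self, generator):
        self.generator = generator

    def __call__(self, task):
        task_type, args = task
        if task_type == "chat_completion":
            # args is [params, messages]; forward both to the underlying generator
            return self.generator.chat_completion(*args)
        raise ValueError(f"Unexpected task type: {task_type}")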
@@ -10,11 +10,13 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

-from llama_stack.apis.inference import CompletionMessage, UserMessage
+from llama_stack.apis.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIUserMessageParam,
+)
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
@@ -136,11 +138,9 @@ async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):

    # Run the shield
    messages = [
-       UserMessage(role="user", content="Hello, how are you?"),
-       CompletionMessage(
-           role="assistant",
+       OpenAIUserMessageParam(content="Hello, how are you?"),
+       OpenAIAssistantMessageParam(
            content="I'm doing well, thank you for asking!",
-           stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
@@ -191,13 +191,10 @@ async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}

    # Run the shield
    messages = [
-       UserMessage(role="user", content="Hello, how are you?"),
-       CompletionMessage(
-           role="assistant",
+       OpenAIUserMessageParam(content="Hello, how are you?"),
+       OpenAIAssistantMessageParam(
            content="I'm doing well, thank you for asking!",
-           stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
@@ -243,7 +240,7 @@ async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
    adapter.shield_store.get_shield.return_value = None

    messages = [
-       UserMessage(role="user", content="Hello, how are you?"),
+       OpenAIUserMessageParam(content="Hello, how are you?"),
    ]

    with pytest.raises(ValueError):
@@ -274,11 +271,9 @@ async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):

    # Running the shield should raise an exception
    messages = [
-       UserMessage(role="user", content="Hello, how are you?"),
-       CompletionMessage(
-           role="assistant",
+       OpenAIUserMessageParam(content="Hello, how are you?"),
+       OpenAIAssistantMessageParam(
            content="I'm doing well, thank you for asking!",
-           stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
@@ -1,220 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from pydantic import ValidationError

from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import (
    CompletionMessage,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIDeveloperMessageParam,
    OpenAIImageURL,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
    SystemMessage,
    UserMessage,
)
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
    convert_message_to_openai_dict_new,
    openai_messages_to_messages,
)


async def test_convert_message_to_openai_dict():
    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
    assert await convert_message_to_openai_dict(message) == {
        "role": "user",
        "content": [{"type": "text", "text": "Hello, world!"}],
    }


# Test convert_message_to_openai_dict with a tool call
async def test_convert_message_to_openai_dict_with_tool_call():
    message = CompletionMessage(
        content="",
        tool_calls=[ToolCall(call_id="123", tool_name="test_tool", arguments='{"foo": "bar"}')],
        stop_reason=StopReason.end_of_turn,
    )

    openai_dict = await convert_message_to_openai_dict(message)

    assert openai_dict == {
        "role": "assistant",
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [
            {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
        ],
    }


async def test_convert_message_to_openai_dict_with_builtin_tool_call():
    message = CompletionMessage(
        content="",
        tool_calls=[
            ToolCall(
                call_id="123",
                tool_name=BuiltinTool.brave_search,
                arguments='{"foo": "bar"}',
            )
        ],
        stop_reason=StopReason.end_of_turn,
    )

    openai_dict = await convert_message_to_openai_dict(message)

    assert openai_dict == {
        "role": "assistant",
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [
            {"id": "123", "type": "function", "function": {"name": "brave_search", "arguments": '{"foo": "bar"}'}}
        ],
    }


async def test_openai_messages_to_messages_with_content_str():
    openai_messages = [
        OpenAISystemMessageParam(content="system message"),
        OpenAIUserMessageParam(content="user message"),
        OpenAIAssistantMessageParam(content="assistant message"),
    ]

    llama_messages = openai_messages_to_messages(openai_messages)
    assert len(llama_messages) == 3
    assert isinstance(llama_messages[0], SystemMessage)
    assert isinstance(llama_messages[1], UserMessage)
    assert isinstance(llama_messages[2], CompletionMessage)
    assert llama_messages[0].content == "system message"
    assert llama_messages[1].content == "user message"
    assert llama_messages[2].content == "assistant message"


async def test_openai_messages_to_messages_with_content_list():
    openai_messages = [
        OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),
        OpenAIUserMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="user message")]),
        OpenAIAssistantMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="assistant message")]),
    ]

    llama_messages = openai_messages_to_messages(openai_messages)
    assert len(llama_messages) == 3
    assert isinstance(llama_messages[0], SystemMessage)
    assert isinstance(llama_messages[1], UserMessage)
    assert isinstance(llama_messages[2], CompletionMessage)
    assert llama_messages[0].content[0].text == "system message"
    assert llama_messages[1].content[0].text == "user message"
    assert llama_messages[2].content[0].text == "assistant message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIUserMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_accepts_text_string(message_class, kwargs):
    """Test that messages accept string text content."""
    msg = message_class(content="Test message", **kwargs)
    assert msg.content == "Test message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIUserMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_accepts_text_list(message_class, kwargs):
    """Test that messages accept list of text content parts."""
    content_list = [OpenAIChatCompletionContentPartTextParam(text="Test message")]
    msg = message_class(content=content_list, **kwargs)
    assert len(msg.content) == 1
    assert msg.content[0].text == "Test message"


@pytest.mark.parametrize(
    "message_class,kwargs",
    [
        (OpenAISystemMessageParam, {}),
        (OpenAIAssistantMessageParam, {}),
        (OpenAIDeveloperMessageParam, {}),
        (OpenAIToolMessageParam, {"tool_call_id": "call_123"}),
    ],
)
def test_message_rejects_images(message_class, kwargs):
    """Test that system, assistant, developer, and tool messages reject image content."""
    with pytest.raises(ValidationError):
        message_class(
            content=[
                OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg"))
            ],
            **kwargs,
        )


def test_user_message_accepts_images():
    """Test that user messages accept image content (unlike other message types)."""
    # List with images should work
    msg = OpenAIUserMessageParam(
        content=[
            OpenAIChatCompletionContentPartTextParam(text="Describe this image:"),
            OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url="http://example.com/image.jpg")),
        ]
    )
    assert len(msg.content) == 2
    assert msg.content[0].text == "Describe this image:"
    assert msg.content[1].image_url.url == "http://example.com/image.jpg"


async def test_convert_message_to_openai_dict_new_user_message():
    """Test convert_message_to_openai_dict_new with UserMessage."""
    message = UserMessage(content="Hello, world!", role="user")
    result = await convert_message_to_openai_dict_new(message)

    assert result["role"] == "user"
    assert result["content"] == "Hello, world!"


async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
    """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
    message = CompletionMessage(
        content="I'll help you find the weather.",
        tool_calls=[
            ToolCall(
                call_id="call_123",
                tool_name="get_weather",
                arguments='{"city": "Sligo"}',
            )
        ],
        stop_reason=StopReason.end_of_turn,
    )
    result = await convert_message_to_openai_dict_new(message)

    # This would have failed with "Cannot instantiate typing.Union" before the fix
    assert result["role"] == "assistant"
    assert result["content"] == "I'll help you find the weather."
    assert "tool_calls" in result
    assert result["tool_calls"] is not None
    assert len(result["tool_calls"]) == 1

    tool_call = result["tool_calls"][0]
    assert tool_call.id == "call_123"
    assert tool_call.type == "function"
    assert tool_call.function.name == "get_weather"
    assert tool_call.function.arguments == '{"city": "Sligo"}'
tests/unit/providers/utils/inference/test_prompt_adapter.py (new file, 35 lines)

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAIUserMessageParam,
)
from llama_stack.models.llama.datatypes import RawTextItem
from llama_stack.providers.utils.inference.prompt_adapter import (
    convert_openai_message_to_raw_message,
)


class TestConvertOpenAIMessageToRawMessage:
    """Test conversion of OpenAI message types to RawMessage format."""

    async def test_user_message_conversion(self):
        msg = OpenAIUserMessageParam(role="user", content="Hello world")
        raw_msg = await convert_openai_message_to_raw_message(msg)

        assert raw_msg.role == "user"
        assert isinstance(raw_msg.content, RawTextItem)
        assert raw_msg.content.text == "Hello world"

    async def test_assistant_message_conversion(self):
        msg = OpenAIAssistantMessageParam(role="assistant", content="Hi there!")
        raw_msg = await convert_openai_message_to_raw_message(msg)

        assert raw_msg.role == "assistant"
        assert isinstance(raw_msg.content, RawTextItem)
        assert raw_msg.content.text == "Hi there!"
        assert raw_msg.tool_calls == []
@@ -104,12 +104,18 @@ async def test_paginated_response_url_setting():

    route_handler = create_dynamic_typed_route(mock_api_method, "get", "/test/route")

-   # Mock minimal request
+   # Mock minimal request with proper state object
    request = MagicMock()
    request.scope = {"user_attributes": {}, "principal": ""}
    request.headers = {}
    request.body = AsyncMock(return_value=b"")

+   # Create a simple state object without auto-generating attributes
+   class MockState:
+       pass
+
+   request.state = MockState()
+
    result = await route_handler(request)

    assert isinstance(result, PaginatedResponse)
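The MockState change above works around MagicMock fabricating attributes on access, which a plain class does not do. A small standalone illustration of the difference:

from unittest.mock import MagicMock

request = MagicMock()
print(hasattr(request.state, "anything"))  # True: MagicMock auto-creates attributes on access

class MockState:
    pass

state = MockState()
print(hasattr(state, "anything"))  # False: a plain object has only what you set on it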