mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-27 07:18:03 +00:00

Merge branch 'main' into fix_DPOAlignmentConfig_schema

commit c897d0b894
5 changed files with 515 additions and 334 deletions
@@ -129,6 +129,22 @@ repos:
         require_serial: true
         always_run: true
         files: ^llama_stack/.*$
+      - id: forbid-pytest-asyncio
+        name: Block @pytest.mark.asyncio and @pytest_asyncio.fixture
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        args:
+          - -c
+          - |
+            grep -EnH '^[^#]*@pytest\.mark\.asyncio|@pytest_asyncio\.fixture' "$@" && {
+              echo;
+              echo "❌ Do not use @pytest.mark.asyncio or @pytest_asyncio.fixture."
+              echo "   pytest is already configured with async-mode=auto."
+              echo;
+              exit 1;
+            } || true

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
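Note on the new hook: grep exits 0 when it finds a match, so the `&& { ...; exit 1; } || true` wrapper inverts that into a failure only when offending decorators exist. The `^[^#]*` guard keeps commented-out decorators from matching, though it covers only the first alternative of the pattern. And because of `files: ^llama_stack/.*$` the hook scans only the library tree, which is presumably why test files later in this commit can still use `@pytest.mark.asyncio`. A small sketch of the same check in Python (helper name and sample input are illustrative, not part of the diff):

    import re

    # Same pattern the hook feeds to grep -E. Alternation binds loosely, so
    # the "no leading #" guard applies only to the @pytest.mark.asyncio branch.
    PATTERN = re.compile(r"^[^#]*@pytest\.mark\.asyncio|@pytest_asyncio\.fixture")

    def offending_lines(source: str) -> list[int]:
        """Return 1-based line numbers that would trip the hook."""
        return [i for i, line in enumerate(source.splitlines(), 1) if PATTERN.search(line)]

    sample = "# @pytest.mark.asyncio (commented out, ignored)\n@pytest.mark.asyncio\n"
    assert offending_lines(sample) == [2]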
@@ -47,8 +47,7 @@ class StackRun(Subcommand):
         self.parser.add_argument(
             "--image-name",
             type=str,
-            default=os.environ.get("CONDA_DEFAULT_ENV"),
-            help="Name of the image to run.",
+            help="Name of the image to run. Defaults to the current environment",
         )
         self.parser.add_argument(
             "--env",
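With the `default=os.environ.get("CONDA_DEFAULT_ENV")` line gone, argparse yields `None` when `--image-name` is omitted, so the fallback to the active environment has to happen at run time rather than being frozen in when the parser is built (argparse evaluates defaults at `add_argument` time). A minimal standalone sketch of the deferred-lookup pattern (not the actual StackRun code):

    import argparse
    import os

    parser = argparse.ArgumentParser()
    # No baked-in default: the parsed value is None when the flag is absent.
    parser.add_argument("--image-name", type=str,
                        help="Name of the image to run. Defaults to the current environment")

    args = parser.parse_args([])
    # Resolve the fallback at use time, so an environment value captured at
    # parser-construction time cannot leak into the parsed arguments.
    image_name = args.image_name or os.environ.get("CONDA_DEFAULT_ENV")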
@@ -5,41 +5,183 @@
 # the root directory of this source tree.

 import time
+from datetime import UTC, datetime
 from uuid import uuid4

 import pytest
 from llama_stack_client import Agent


-@pytest.mark.skip(reason="telemetry is not stable")
-def test_agent_query_spans(llama_stack_client, text_model_id):
+@pytest.fixture(scope="module", autouse=True)
+def setup_telemetry_data(llama_stack_client, text_model_id):
+    """Setup fixture that creates telemetry data before tests run."""
     agent = Agent(llama_stack_client, model=text_model_id, instructions="You are a helpful assistant")
-    session_id = agent.create_session(f"test-session-{uuid4()}")
-    agent.create_turn(
-        messages=[
-            {
-                "role": "user",
-                "content": "Give me a sentence that contains the word: hello",
-            }
-        ],
-        session_id=session_id,
-        stream=False,
-    )
-
-    # Wait for the span to be logged
-    time.sleep(2)
-
-    agent_logs = []
-
-    for span in llama_stack_client.telemetry.query_spans(
-        attribute_filters=[
-            {"key": "session_id", "op": "eq", "value": session_id},
-        ],
-        attributes_to_return=["input", "output"],
-    ):
-        if span.attributes["output"] != "no shields":
-            agent_logs.append(span.attributes)
-
-    assert len(agent_logs) == 1
-    assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"]
-    assert "hello" in agent_logs[0]["output"].lower()
+
+    session_id = agent.create_session(f"test-setup-session-{uuid4()}")
+
+    messages = [
+        "What is 2 + 2?",
+        "Tell me a short joke",
+    ]
+    for msg in messages:
+        agent.create_turn(
+            messages=[{"role": "user", "content": msg}],
+            session_id=session_id,
+            stream=False,
+        )
+
+    for i in range(2):
+        llama_stack_client.inference.chat_completion(
+            model_id=text_model_id, messages=[{"role": "user", "content": f"Test trace {i}"}]
+        )
+
+    start_time = time.time()
+    while time.time() - start_time < 30:
+        traces = llama_stack_client.telemetry.query_traces(limit=10)
+        if len(traces) >= 4:
+            break
+        time.sleep(1)
+
+    if len(traces) < 4:
+        pytest.fail(f"Failed to create sufficient telemetry data after 30s. Got {len(traces)} traces.")
+
+    yield
+
+
+def test_query_traces_basic(llama_stack_client):
+    """Test basic trace querying functionality with proper data validation."""
+    all_traces = llama_stack_client.telemetry.query_traces(limit=5)
+
+    assert isinstance(all_traces, list), "Should return a list of traces"
+    assert len(all_traces) >= 4, "Should have at least 4 traces from setup"
+
+    # Verify trace structure and data quality
+    first_trace = all_traces[0]
+    assert hasattr(first_trace, "trace_id"), "Trace should have trace_id"
+    assert hasattr(first_trace, "start_time"), "Trace should have start_time"
+    assert hasattr(first_trace, "root_span_id"), "Trace should have root_span_id"
+
+    # Validate trace_id is a valid UUID format
+    assert isinstance(first_trace.trace_id, str) and len(first_trace.trace_id) > 0, (
+        "trace_id should be non-empty string"
+    )
+
+    # Validate start_time format and not in the future
+    now = datetime.now(UTC)
+    if isinstance(first_trace.start_time, str):
+        trace_time = datetime.fromisoformat(first_trace.start_time.replace("Z", "+00:00"))
+    else:
+        # start_time is already a datetime object
+        trace_time = first_trace.start_time
+        if trace_time.tzinfo is None:
+            trace_time = trace_time.replace(tzinfo=UTC)
+
+    # Ensure trace time is not in the future (but allow any age in the past for persistent test data)
+    time_diff = (now - trace_time).total_seconds()
+    assert time_diff >= 0, f"Trace start_time should not be in the future, got {time_diff}s"
+
+    # Validate root_span_id exists and is non-empty
+    assert isinstance(first_trace.root_span_id, str) and len(first_trace.root_span_id) > 0, (
+        "root_span_id should be non-empty string"
+    )
+
+    # Test querying specific trace by ID
+    specific_trace = llama_stack_client.telemetry.get_trace(trace_id=first_trace.trace_id)
+    assert specific_trace.trace_id == first_trace.trace_id, "Retrieved trace should match requested ID"
+    assert specific_trace.start_time == first_trace.start_time, "Retrieved trace should have same start_time"
+    assert specific_trace.root_span_id == first_trace.root_span_id, "Retrieved trace should have same root_span_id"
+
+    # Test pagination with proper validation
+    recent_traces = llama_stack_client.telemetry.query_traces(limit=3, offset=0)
+    assert len(recent_traces) <= 3, "Should return at most 3 traces when limit=3"
+    assert len(recent_traces) >= 1, "Should return at least 1 trace"
+
+    # Verify all traces have required fields
+    for trace in recent_traces:
+        assert hasattr(trace, "trace_id") and trace.trace_id, "Each trace should have non-empty trace_id"
+        assert hasattr(trace, "start_time") and trace.start_time, "Each trace should have non-empty start_time"
+        assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id"
+
+
+def test_query_spans_basic(llama_stack_client):
+    """Test basic span querying functionality with proper validation."""
+    spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[])
+
+    assert isinstance(spans, list), "Should return a list of spans"
+    assert len(spans) >= 1, "Should have at least one span from setup"
+
+    # Verify span structure and data quality
+    first_span = spans[0]
+    required_attrs = ["span_id", "name", "trace_id"]
+    for attr in required_attrs:
+        assert hasattr(first_span, attr), f"Span should have {attr} attribute"
+        assert getattr(first_span, attr), f"Span {attr} should not be empty"
+
+    # Validate span data types and content
+    assert isinstance(first_span.span_id, str) and len(first_span.span_id) > 0, "span_id should be non-empty string"
+    assert isinstance(first_span.name, str) and len(first_span.name) > 0, "span name should be non-empty string"
+    assert isinstance(first_span.trace_id, str) and len(first_span.trace_id) > 0, "trace_id should be non-empty string"
+
+    # Verify span belongs to a valid trace (test with traces we know exist)
+    all_traces = llama_stack_client.telemetry.query_traces(limit=10)
+    trace_ids = {t.trace_id for t in all_traces}
+    if first_span.trace_id in trace_ids:
+        trace = llama_stack_client.telemetry.get_trace(trace_id=first_span.trace_id)
+        assert trace is not None, "Should be able to retrieve trace for valid trace_id"
+        assert trace.trace_id == first_span.trace_id, "Trace ID should match span's trace_id"
+
+    # Test with span filtering and validate results
+    filtered_spans = llama_stack_client.telemetry.query_spans(
+        attribute_filters=[{"key": "name", "op": "eq", "value": first_span.name}],
+        attributes_to_return=["name", "span_id"],
+    )
+    assert isinstance(filtered_spans, list), "Should return a list with span name filter"
+
+    # Validate filtered spans if filtering works
+    if len(filtered_spans) > 0:
+        for span in filtered_spans:
+            assert hasattr(span, "name"), "Filtered spans should have name attribute"
+            assert hasattr(span, "span_id"), "Filtered spans should have span_id attribute"
+            assert span.name == first_span.name, "Filtered spans should match the filter criteria"
+            assert isinstance(span.span_id, str) and len(span.span_id) > 0, "Filtered span_id should be valid"
+
+    # Test that all spans have consistent structure
+    for span in spans:
+        for attr in required_attrs:
+            assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}"
+
+
+def test_telemetry_pagination(llama_stack_client):
+    """Test pagination in telemetry queries."""
+    # Get total count of traces
+    all_traces = llama_stack_client.telemetry.query_traces(limit=20)
+    total_count = len(all_traces)
+    assert total_count >= 4, "Should have at least 4 traces from setup"
+
+    # Test trace pagination
+    page1 = llama_stack_client.telemetry.query_traces(limit=2, offset=0)
+    page2 = llama_stack_client.telemetry.query_traces(limit=2, offset=2)
+
+    assert len(page1) == 2, "First page should have exactly 2 traces"
+    assert len(page2) >= 1, "Second page should have at least 1 trace"
+
+    # Verify no overlap between pages
+    page1_ids = {t.trace_id for t in page1}
+    page2_ids = {t.trace_id for t in page2}
+    assert len(page1_ids.intersection(page2_ids)) == 0, "Pages should contain different traces"
+
+    # Test ordering
+    ordered_traces = llama_stack_client.telemetry.query_traces(limit=5, order_by=["start_time"])
+    assert len(ordered_traces) >= 4, "Should have at least 4 traces for ordering test"
+
+    # Verify ordering by start_time
+    for i in range(len(ordered_traces) - 1):
+        current_time = ordered_traces[i].start_time
+        next_time = ordered_traces[i + 1].start_time
+        assert current_time <= next_time, f"Traces should be ordered by start_time: {current_time} > {next_time}"
+
+    # Test limit behavior
+    limited = llama_stack_client.telemetry.query_traces(limit=3)
+    assert len(limited) == 3, "Should return exactly 3 traces when limit=3"
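The rewritten setup replaces the old fixed `time.sleep(2)` with a bounded poll: query until at least four traces appear or 30 seconds pass, then fail fast with a clear message. The pattern generalizes; a small helper along these lines (the name and signature are illustrative, not part of the diff) keeps the timeout logic out of individual fixtures:

    import time
    from typing import Callable

    def wait_for(condition: Callable[[], bool], timeout: float = 30.0, interval: float = 1.0) -> bool:
        """Poll `condition` until it returns True or `timeout` seconds elapse."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            if condition():
                return True
            time.sleep(interval)
        return False

    # Mirrors the fixture's loop (client is assumed to be a llama_stack_client instance):
    # if not wait_for(lambda: len(client.telemetry.query_traces(limit=10)) >= 4):
    #     pytest.fail("telemetry data did not appear within 30s")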
@@ -4,14 +4,14 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import asyncio
-import unittest
+import pytest

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionMessage,
     StopReason,
     SystemMessage,
+    SystemMessageBehavior,
     ToolCall,
     ToolConfig,
     UserMessage,
@@ -25,17 +25,15 @@ from llama_stack.models.llama.datatypes import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_messages,
     chat_completion_request_to_prompt,
+    interleaved_content_as_str,
 )

 MODEL = "Llama3.1-8B-Instruct"
 MODEL3_2 = "Llama3.2-3B-Instruct"


-class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        asyncio.get_running_loop().set_debug(False)
-
-    async def test_system_default(self):
+@pytest.mark.asyncio
+async def test_system_default():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
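The hunks below route substring checks through `interleaved_content_as_str` (imported above) because a message's `content` may be a plain string or a list of interleaved content items; comparing against `messages[0].content` directly breaks in the list case. A rough sketch of the assumed normalization (simplified and illustrative only; the real helper in prompt_adapter may differ in item types and separators):

    def as_str_sketch(content) -> str:
        """Flatten str-or-list message content into plain text (illustrative only)."""
        if isinstance(content, str):
            return content
        # Assume list items are raw strings or text-bearing objects.
        return "".join(item if isinstance(item, str) else getattr(item, "text", "") for item in content)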
@@ -44,11 +42,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2)
-        self.assertEqual(messages[-1].content, content)
-        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)


-    async def test_system_builtin_only(self):
+@pytest.mark.asyncio
+async def test_system_builtin_only():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -61,12 +61,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2)
-        self.assertEqual(messages[-1].content, content)
-        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
-        self.assertTrue("Tools: brave_search" in messages[0].content)
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
+    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)


-    async def test_system_custom_only(self):
+@pytest.mark.asyncio
+async def test_system_custom_only():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -89,13 +91,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 3)
-        self.assertTrue("Environment: ipython" in messages[0].content)
-        self.assertTrue("Return function calls in JSON format" in messages[1].content)
-        self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 3
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+
+    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
+    assert messages[-1].content == content


-    async def test_system_custom_and_builtin(self):
+@pytest.mark.asyncio
+async def test_system_custom_and_builtin():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -119,15 +123,17 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 3)
-        self.assertTrue("Environment: ipython" in messages[0].content)
-        self.assertTrue("Tools: brave_search" in messages[0].content)
-        self.assertTrue("Return function calls in JSON format" in messages[1].content)
-        self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 3
+
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
+
+    assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content)
+    assert messages[-1].content == content


-    async def test_completion_message_encoding(self):
+@pytest.mark.asyncio
+async def test_completion_message_encoding():
     request = ChatCompletionRequest(
         model=MODEL3_2,
         messages=[
@@ -160,17 +166,16 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
     )
     prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('[custom1(param1="value1")]', prompt)
+    assert '[custom1(param1="value1")]' in prompt

     request.model = MODEL
-        request.tool_config.tool_prompt_format = ToolPromptFormat.json
+    request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json)
     prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn(
-            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
-            prompt,
-        )
+    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt


-    async def test_user_provided_system_message(self):
+@pytest.mark.asyncio
+async def test_user_provided_system_message():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
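The switch from mutating `request.tool_config.tool_prompt_format` in place to assigning a fresh `ToolConfig` matters if the model validates on construction but not on attribute assignment; that is typical for Pydantic models with default settings, though whether ToolConfig behaves this way is an inference here, not something the diff states. Rebuilding the object re-runs validation, while in-place mutation can slip past it:

    from pydantic import BaseModel

    class Config(BaseModel):    # hypothetical stand-in for ToolConfig
        fmt: str = "python_list"

    cfg = Config()
    cfg.fmt = 123               # validate_assignment is off by default: the int is stored as-is
    fresh = Config(fmt="json")  # constructing anew always validates
    print(type(cfg.fmt), fresh.fmt)   # <class 'int'> json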
@@ -184,12 +189,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 2
+    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
+
+    assert messages[-1].content == content


-    async def test_repalce_system_message_behavior_builtin_tools(self):
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_builtin_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -203,17 +210,19 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
         tool_config=ToolConfig(
             tool_choice="auto",
-            tool_prompt_format="python_list",
-            system_message_behavior="replace",
+            tool_prompt_format=ToolPromptFormat.python_list,
+            system_message_behavior=SystemMessageBehavior.replace,
         ),
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 2
+    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert messages[-1].content == content


-    async def test_repalce_system_message_behavior_custom_tools(self):
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -238,18 +247,20 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
         tool_config=ToolConfig(
             tool_choice="auto",
-            tool_prompt_format="python_list",
-            system_message_behavior="replace",
+            tool_prompt_format=ToolPromptFormat.python_list,
+            system_message_behavior=SystemMessageBehavior.replace,
         ),
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)

-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
+    assert len(messages) == 2
+    assert interleaved_content_as_str(messages[0].content).endswith(system_prompt)
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert messages[-1].content == content


-    async def test_replace_system_message_behavior_custom_tools_with_template(self):
+@pytest.mark.asyncio
+async def test_replace_system_message_behavior_custom_tools_with_template():
     content = "Hello !"
     system_prompt = "You are a pirate {{ function_description }}"
     request = ChatCompletionRequest(
@@ -274,15 +285,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
         tool_config=ToolConfig(
             tool_choice="auto",
-            tool_prompt_format="python_list",
-            system_message_behavior="replace",
+            tool_prompt_format=ToolPromptFormat.python_list,
+            system_message_behavior=SystemMessageBehavior.replace,
         ),
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)

-        self.assertEqual(len(messages), 2, messages)
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertIn("You are a pirate", messages[0].content)
+    assert len(messages) == 2
+    assert "Environment: ipython" in interleaved_content_as_str(messages[0].content)
+    assert "You are a pirate" in interleaved_content_as_str(messages[0].content)
     # function description is present in the system prompt
-        self.assertIn('"name": "custom1"', messages[0].content)
-        self.assertEqual(messages[-1].content, content)
+    assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content)
+    assert messages[-1].content == content
@@ -5,13 +5,14 @@
 # the root directory of this source tree.

 import os
-import unittest
+
+import pytest

 from llama_stack.distribution.stack import replace_env_vars


-class TestReplaceEnvVars(unittest.TestCase):
-    def setUp(self):
+@pytest.fixture
+def setup_env_vars():
     # Clear any existing environment variables we'll use in tests
     for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
         if var in os.environ:
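The unittest `setUp` becomes a plain pytest fixture; everything after its `yield` (visible in the next hunk) is the new `tearDown`, running even when the test body fails. The shape in isolation, as a minimal standalone sketch:

    import os
    import pytest

    @pytest.fixture
    def scratch_env():
        os.environ["SCRATCH"] = "1"      # setup, runs before the test body
        yield                            # the test executes here
        os.environ.pop("SCRATCH", None)  # teardown, runs even if the test failed

    def test_uses_scratch(scratch_env):
        assert os.environ["SCRATCH"] == "1"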
@@ -22,56 +23,68 @@ class TestReplaceEnvVars(unittest.TestCase):
     os.environ["EMPTY_VAR"] = ""
     os.environ["ZERO_VAR"] = "0"

-    def test_simple_replacement(self):
-        self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")
-
-    def test_default_value_when_not_set(self):
-        self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default")
-
-    def test_default_value_when_set(self):
-        self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value")
-
-    def test_default_value_when_empty(self):
-        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default")
-
-    def test_none_value_when_empty(self):
-        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None)
-
-    def test_value_when_set(self):
-        self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value")
-
-    def test_empty_var_no_default(self):
-        self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None)
-
-    def test_conditional_value_when_set(self):
-        self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional")
-
-    def test_conditional_value_when_not_set(self):
-        self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None)
-
-    def test_conditional_value_when_empty(self):
-        self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None)
-
-    def test_conditional_value_with_zero(self):
-        self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional")
-
-    def test_mixed_syntax(self):
-        self.assertEqual(
-            replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and "
-        )
-        self.assertEqual(
-            replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional"
-        )
-
-    def test_nested_structures(self):
+    yield
+
+    # Cleanup after test
+    for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
+        if var in os.environ:
+            del os.environ[var]
+
+
+def test_simple_replacement(setup_env_vars):
+    assert replace_env_vars("${env.TEST_VAR}") == "test_value"
+
+
+def test_default_value_when_not_set(setup_env_vars):
+    assert replace_env_vars("${env.NOT_SET:=default}") == "default"
+
+
+def test_default_value_when_set(setup_env_vars):
+    assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"
+
+
+def test_default_value_when_empty(setup_env_vars):
+    assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"
+
+
+def test_none_value_when_empty(setup_env_vars):
+    assert replace_env_vars("${env.EMPTY_VAR:=}") is None
+
+
+def test_value_when_set(setup_env_vars):
+    assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"
+
+
+def test_empty_var_no_default(setup_env_vars):
+    assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None
+
+
+def test_conditional_value_when_set(setup_env_vars):
+    assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"
+
+
+def test_conditional_value_when_not_set(setup_env_vars):
+    assert replace_env_vars("${env.NOT_SET:+conditional}") is None
+
+
+def test_conditional_value_when_empty(setup_env_vars):
+    assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None
+
+
+def test_conditional_value_with_zero(setup_env_vars):
+    assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"
+
+
+def test_mixed_syntax(setup_env_vars):
+    assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
+    assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"
+
+
+def test_nested_structures(setup_env_vars):
     data = {
         "key1": "${env.TEST_VAR:=default}",
         "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
         "key3": {"nested": "${env.NOT_SET:+conditional}"},
     }
     expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
-        self.assertEqual(replace_env_vars(data), expected)
+    assert replace_env_vars(data) == expected

-
-if __name__ == "__main__":
-    unittest.main()
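The test matrix pins down the two substitution operators, which mirror their POSIX shell cousins: `${env.VAR:=default}` falls back when the variable is unset or empty (an empty default yields None), and `${env.VAR:+value}` produces `value` only when the variable is set and non-empty (note that `"0"` counts as set). A single-variable sketch of that rule, consistent with the assertions above (illustrative only; the real `replace_env_vars` also recurses into dicts and lists and splices results into surrounding text):

    import os

    def resolve(var: str, op: str, operand: str):
        """Resolve one ${env.VAR<op><operand>} expression; op is ':=' or ':+'."""
        value = os.environ.get(var)
        if op == ":=":
            # default: keep a non-empty value, else fall back (empty operand -> None)
            return value if value else (operand or None)
        if op == ":+":
            # conditional: emit operand only when the variable is set and non-empty
            if value:
                return operand if operand else None
            return None
        raise ValueError(f"unknown operator {op!r}")

    os.environ["ZERO_VAR"] = "0"
    assert resolve("ZERO_VAR", ":+", "conditional") == "conditional"  # "0" is non-empty
    assert resolve("NOT_SET", ":=", "default") == "default"
    assert resolve("EMPTY_VAR_NO_DEFAULT", ":+", "") is None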