# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from collections.abc import AsyncGenerator, Callable
from pathlib import Path
from typing import Any
from unittest.mock import Mock, patch

import pytest

from llama_stack.apis.inference import ToolDefinition
from llama_stack.apis.tools import ToolInvocationResult
from llama_stack.providers.inline.agents.meta_reference.agent_instance import ChatAgent
from llama_stack.providers.inline.telemetry.meta_reference.config import (
    TelemetryConfig,
    TelemetrySink,
)
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
    TelemetryAdapter,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite.sqlite import SqliteKVStoreImpl
from llama_stack.providers.utils.telemetry import tracing as telemetry_tracing

@pytest.fixture
def make_agent_fixture():
    """Factory that builds a ChatAgent wired to mocked dependencies plus the provided telemetry adapter and kvstore."""

    def _make(telemetry, kvstore) -> ChatAgent:
        agent = ChatAgent(
            agent_id="test-agent",
            agent_config=Mock(),
            inference_api=Mock(),
            safety_api=Mock(),
            tool_runtime_api=Mock(),
            tool_groups_api=Mock(),
            vector_io_api=Mock(),
            telemetry_api=telemetry,
            persistence_store=kvstore,
            created_at="2025-01-01T00:00:00Z",
            policy=[],
        )
        agent.agent_config.client_tools = []
        agent.agent_config.max_infer_iters = 5
        agent.input_shields = []
        agent.output_shields = []
        agent.tool_defs = [
            ToolDefinition(tool_name="web_search", description="", parameters={}),
            ToolDefinition(tool_name="knowledge_search", description="", parameters={}),
        ]
        agent.tool_name_to_args = {}

        # Stub tool runtime invoke_tool
        async def _mock_invoke_tool(
            *args: Any,
            tool_name: str | None = None,
            kwargs: dict | None = None,
            **extra: Any,
        ):
            return ToolInvocationResult(content="Tool execution result")

        agent.tool_runtime_api.invoke_tool = _mock_invoke_tool
        return agent

    return _make

def _chat_stream(tool_name: str | None, content: str = ""):
    """Return an async generator that mimics a streaming chat completion: start, optional text, optional tool call, complete."""
    from llama_stack.apis.common.content_types import (
        TextDelta,
        ToolCallDelta,
        ToolCallParseStatus,
    )
    from llama_stack.apis.inference import (
        ChatCompletionResponseEvent,
        ChatCompletionResponseEventType,
        ChatCompletionResponseStreamChunk,
        StopReason,
    )
    from llama_stack.models.llama.datatypes import ToolCall

    async def gen():
        # Start
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.start,
                delta=TextDelta(text=""),
            )
        )

        # Content
        if content:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=TextDelta(text=content),
                )
            )

        # Tool call if specified
        if tool_name:
            yield ChatCompletionResponseStreamChunk(
                event=ChatCompletionResponseEvent(
                    event_type=ChatCompletionResponseEventType.progress,
                    delta=ToolCallDelta(
                        tool_call=ToolCall(call_id="call_0", tool_name=tool_name, arguments={}),
                        parse_status=ToolCallParseStatus.succeeded,
                    ),
                )
            )

        # Complete
        yield ChatCompletionResponseStreamChunk(
            event=ChatCompletionResponseEvent(
                event_type=ChatCompletionResponseEventType.complete,
                delta=TextDelta(text=""),
                stop_reason=StopReason.end_of_turn,
            )
        )

    return gen()

@pytest.fixture
async def telemetry(tmp_path: Path) -> AsyncGenerator[TelemetryAdapter, None]:
    """TelemetryAdapter writing to the console and a temporary SQLite trace store."""
    db_path = tmp_path / "trace_store.db"
    cfg = TelemetryConfig(
        sinks=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
        sqlite_db_path=str(db_path),
    )
    telemetry = TelemetryAdapter(cfg, deps={})
    telemetry_tracing.setup_logger(telemetry)
    try:
        yield telemetry
    finally:
        await telemetry.shutdown()

@pytest.fixture
async def kvstore(tmp_path: Path) -> SqliteKVStoreImpl:
    """SQLite-backed key-value store used as the agent's persistence store."""
    kv_path = tmp_path / "agent_kvstore.db"
    kv = SqliteKVStoreImpl(SqliteKVStoreConfig(db_path=str(kv_path)))
    await kv.initialize()
    return kv

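
# Illustrative only (not part of the original file): a hypothetical smoke check
# for the kvstore fixture. It assumes the KVStore interface exposes async
# set/get methods taking string keys and values, and that async tests run via
# pytest-asyncio; adjust if either assumption does not hold.
@pytest.mark.asyncio
async def test_kvstore_roundtrip_sketch(kvstore):
    await kvstore.set("agent:test-agent", "{}")  # store a dummy payload
    assert await kvstore.get("agent:test-agent") == "{}"  # read it back
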
@pytest.fixture
def span_patch():
    """Patch the agent's current-span lookup and span-id generation so tracing sees stable, well-formed IDs."""
    with (
        patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.get_current_span") as mock_span,
        patch(
            "llama_stack.providers.utils.telemetry.tracing.generate_span_id",
            return_value="0000000000000abc",
        ),
    ):
        mock_span.return_value = Mock(get_span_context=Mock(return_value=Mock(trace_id=0x123, span_id=0xABC)))
        yield

@pytest.fixture
def make_completion_fn() -> Callable[[str | None, str], Callable]:
    """Factory producing async chat_completion stubs whose return value is a mock stream from _chat_stream."""

    def _factory(tool_name: str | None = None, content: str = "") -> Callable:
        async def chat_completion(*args: Any, **kwargs: Any):
            return _chat_stream(tool_name, content)

        return chat_completion

    return _factory
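
# Illustrative only (not part of the original file): a minimal sketch of how the
# fixtures above could compose in a test. The mocked chat_completion factory is
# attached to the agent's mocked inference API and its stream is drained
# directly; only names defined in this file are used. The test name and the
# pytest-asyncio marker are assumptions.
@pytest.mark.asyncio
async def test_mock_completion_stream_shape_sketch(make_agent_fixture, telemetry, kvstore, make_completion_fn):
    from llama_stack.apis.inference import ChatCompletionResponseEventType

    agent = make_agent_fixture(telemetry, kvstore)
    agent.inference_api.chat_completion = make_completion_fn(tool_name="web_search", content="searching")

    # chat_completion() awaits to the async generator built by _chat_stream.
    chunks = [chunk async for chunk in await agent.inference_api.chat_completion()]

    # Expected sequence: start, text progress, tool-call progress, complete.
    assert len(chunks) == 4
    assert chunks[0].event.event_type == ChatCompletionResponseEventType.start
    assert chunks[-1].event.event_type == ChatCompletionResponseEventType.complete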