refactor(test): introduce --stack-config and simplify options (#1404)

You now run the integration tests with these options:

```bash
Custom options:
  --stack-config=STACK_CONFIG
                        a 'pointer' to the stack. this can be either be:
                        (a) a template name like `fireworks`, or
                        (b) a path to a run.yaml file, or
                        (c) an adhoc config spec, e.g.
                        `inference=fireworks,safety=llama-guard,agents=meta-
                        reference`
  --env=ENV             Set environment variables, e.g. --env KEY=value
  --text-model=TEXT_MODEL
                        comma-separated list of text models. Fixture name:
                        text_model_id
  --vision-model=VISION_MODEL
                        comma-separated list of vision models. Fixture name:
                        vision_model_id
  --embedding-model=EMBEDDING_MODEL
                        comma-separated list of embedding models. Fixture name:
                        embedding_model_id
  --safety-shield=SAFETY_SHIELD
                        comma-separated list of safety shields. Fixture name:
                        shield_id
  --judge-model=JUDGE_MODEL
                        comma-separated list of judge models. Fixture name:
                        judge_model_id
  --embedding-dimension=EMBEDDING_DIMENSION
                        Output dimensionality of the embedding model to use for
                        testing. Default: 384
  --record-responses    Record new API responses instead of using cached ones.
  --report=REPORT       Path where the test report should be written, e.g.
                        --report=/path/to/report.md

```

Importantly, if you don't specify any of the models (text-model,
vision-model, etc.) the relevant tests will get **skipped!**

This will make running tests somewhat more annoying since all options
will need to be specified. We will make this easier by adding some easy
wrapper yaml configs.

## Test Plan

Example:

```bash
ashwin@ashwin-mbp ~/local/llama-stack/tests/integration (unify_tests) $ 
LLAMA_STACK_CONFIG=fireworks pytest -s -v inference/test_text_inference.py \
   --text-model meta-llama/Llama-3.2-3B-Instruct 
```
This commit is contained in:
Ashwin Bharambe 2025-03-05 17:02:02 -08:00 committed by GitHub
parent a0d6b165b0
commit 2fe976ed0a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 536 additions and 1144 deletions

View file

@ -7,6 +7,7 @@
import importlib.resources
import os
import re
import tempfile
from typing import Any, Dict, Optional
import yaml
@ -33,10 +34,11 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import Api
@ -228,3 +230,53 @@ def get_stack_run_config_from_template(template: str) -> StackRunConfig:
run_config = yaml.safe_load(path.open())
return StackRunConfig(**replace_env_vars(run_config))
def run_config_from_adhoc_config_spec(
adhoc_config_spec: str, provider_registry: Optional[ProviderRegistry] = None
) -> StackRunConfig:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = provider_registry or get_provider_registry()
distro_dir = tempfile.mkdtemp()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
providers=provider_configs_by_api,
)
return config

View file

@ -1,411 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
from typing import AsyncIterator, List, Optional, Union
import pytest
from llama_stack.apis.agents import (
AgentConfig,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
AgentTurnResponseTurnCompletePayload,
StepType,
)
from llama_stack.apis.common.content_types import URL, TextDelta
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
UserMessage,
)
from llama_stack.apis.safety import RunShieldResponse
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolDef,
ToolGroup,
ToolHost,
ToolInvocationResult,
)
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.inline.agents.meta_reference.agents import (
MetaReferenceAgentsImpl,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class MockInferenceAPI:
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = None,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
async def stream_response():
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=TextDelta(text="AI is a fascinating field..."),
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=TextDelta(text=""),
stop_reason=StopReason.end_of_turn,
)
)
if stream:
return stream_response()
else:
return ChatCompletionResponse(
completion_message=CompletionMessage(
role="assistant",
content="Mock response",
stop_reason="end_of_turn",
),
logprobs={"token_logprobs": [0.1, 0.2, 0.3]} if logprobs else None,
)
class MockSafetyAPI:
async def run_shield(self, shield_id: str, messages: List[Message]) -> RunShieldResponse:
return RunShieldResponse(violation=None)
class MockVectorIOAPI:
def __init__(self):
self.chunks = {}
async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
for chunk in chunks:
metadata = chunk.metadata
self.chunks[vector_db_id][metadata["document_id"]] = chunk
async def query_chunks(self, vector_db_id, query, params=None):
if vector_db_id not in self.chunks:
raise ValueError(f"Bank {vector_db_id} not found")
chunks = list(self.chunks[vector_db_id].values())
scores = [1.0] * len(chunks)
return QueryChunksResponse(chunks=chunks, scores=scores)
class MockToolGroupsAPI:
async def register_tool_group(self, toolgroup_id: str, provider_id: str, mcp_endpoint=None, args=None) -> None:
pass
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
return ToolGroup(
identifier=toolgroup_id,
provider_resource_id=toolgroup_id,
)
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=[])
async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
if toolgroup_id == MEMORY_TOOLGROUP:
return ListToolsResponse(
data=[
Tool(
identifier=MEMORY_QUERY_TOOL,
provider_resource_id=MEMORY_QUERY_TOOL,
toolgroup_id=MEMORY_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::rag",
parameters=[],
)
]
)
if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
return ListToolsResponse(
data=[
Tool(
identifier="code_interpreter",
provider_resource_id="code_interpreter",
toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
tool_host=ToolHost.client,
description="Mock tool",
provider_id="builtin::code_interpreter",
parameters=[],
)
]
)
return ListToolsResponse(data=[])
async def get_tool(self, tool_name: str) -> Tool:
return Tool(
identifier=tool_name,
provider_resource_id=tool_name,
toolgroup_id="mock_group",
tool_host=ToolHost.client,
description="Mock tool",
provider_id="mock_provider",
parameters=[],
)
async def unregister_tool_group(self, toolgroup_id: str) -> None:
pass
class MockToolRuntimeAPI:
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return []
async def invoke_tool(self, tool_name: str, args: dict) -> ToolInvocationResult:
return ToolInvocationResult(content={"result": "Mock tool result"})
@pytest.fixture
def mock_inference_api():
return MockInferenceAPI()
@pytest.fixture
def mock_safety_api():
return MockSafetyAPI()
@pytest.fixture
def mock_vector_io_api():
return MockVectorIOAPI()
@pytest.fixture
def mock_tool_groups_api():
return MockToolGroupsAPI()
@pytest.fixture
def mock_tool_runtime_api():
return MockToolRuntimeAPI()
@pytest.fixture
async def get_agents_impl(
mock_inference_api,
mock_safety_api,
mock_vector_io_api,
mock_tool_runtime_api,
mock_tool_groups_api,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
impl = MetaReferenceAgentsImpl(
config=MetaReferenceAgentsImplConfig(
persistence_store=SqliteKVStoreConfig(
db_name=sqlite_file.name,
),
),
inference_api=mock_inference_api,
safety_api=mock_safety_api,
vector_io_api=mock_vector_io_api,
tool_runtime_api=mock_tool_runtime_api,
tool_groups_api=mock_tool_groups_api,
)
await impl.initialize()
return impl
@pytest.fixture
async def get_chat_agent(get_agents_impl):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=[],
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
return await impl.get_agent(response.agent_id)
MEMORY_TOOLGROUP = "builtin::rag"
CODE_INTERPRETER_TOOLGROUP = "builtin::code_interpreter"
@pytest.fixture
async def get_chat_agent_with_tools(get_agents_impl, request):
impl = await get_agents_impl
toolgroups = request.param
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
return await impl.get_agent(response.agent_id)
@pytest.mark.asyncio
async def test_chat_agent_create_and_execute_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Hello")],
stream=True,
)
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
assert len(responses) > 0
assert (
len(responses) == 7
) # TurnStart, ShieldCallStart, ShieldCallComplete, StepStart, StepProgress, StepComplete, TurnComplete
assert responses[0].event.payload.turn_id is not None
@pytest.mark.asyncio
async def test_run_multiple_shields_wrapper(get_chat_agent):
chat_agent = await get_chat_agent
messages = [UserMessage(content="Test message")]
shields = ["test_shield"]
responses = [
chunk
async for chunk in chat_agent.run_multiple_shields_wrapper(
turn_id="test_turn_id",
messages=messages,
shields=shields,
touchpoint="user-input",
)
]
assert len(responses) == 2 # StepStart, StepComplete
assert responses[0].event.payload.step_type.value == "shield_call"
assert not responses[1].event.payload.step_details.violation
@pytest.mark.asyncio
async def test_chat_agent_complex_turn(get_chat_agent):
chat_agent = await get_chat_agent
session_id = await chat_agent.create_session("Test Session")
request = AgentTurnCreateRequest(
agent_id=chat_agent.agent_id,
session_id=session_id,
messages=[UserMessage(content="Tell me about AI and then use a tool.")],
stream=True,
)
responses = []
async for response in chat_agent.create_and_execute_turn(request):
responses.append(response)
assert len(responses) > 0
step_types = [
response.event.payload.step_type for response in responses if hasattr(response.event.payload, "step_type")
]
assert StepType.shield_call in step_types, "Shield call step is missing"
assert StepType.inference in step_types, "Inference step is missing"
event_types = [
response.event.payload.event_type for response in responses if hasattr(response.event.payload, "event_type")
]
assert "turn_start" in event_types, "Start event is missing"
assert "turn_complete" in event_types, "Complete event is missing"
assert any(isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload) for response in responses), (
"Turn complete event is missing"
)
turn_complete_payload = next(
response.event.payload
for response in responses
if isinstance(response.event.payload, AgentTurnResponseTurnCompletePayload)
)
turn = turn_complete_payload.turn
assert turn.input_messages == request.messages, "Input messages do not match"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"toolgroups, expected_memory, expected_code_interpreter",
[
([], False, False), # no tools
([MEMORY_TOOLGROUP], True, False), # memory only
([CODE_INTERPRETER_TOOLGROUP], False, True), # code interpreter only
([MEMORY_TOOLGROUP, CODE_INTERPRETER_TOOLGROUP], True, True), # all tools
],
)
async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, expected_code_interpreter):
impl = await get_agents_impl
agent_config = AgentConfig(
model="test_model",
instructions="You are a helpful assistant.",
toolgroups=toolgroups,
tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=["test_shield"],
)
response = await impl.create_agent(agent_config)
chat_agent = await impl.get_agent(response.agent_id)
tool_defs, _ = await chat_agent._get_tool_defs()
tool_defs_names = [t.tool_name for t in tool_defs]
if expected_memory:
assert MEMORY_QUERY_TOOL in tool_defs_names
if expected_code_interpreter:
assert BuiltinTool.code_interpreter in tool_defs_names
if expected_memory and expected_code_interpreter:
# override the tools for turn
new_tool_defs, _ = await chat_agent._get_tool_defs(
toolgroups_for_turn=[
AgentToolGroupWithArgs(
name=MEMORY_TOOLGROUP,
args={"vector_dbs": ["test_vector_db"]},
)
]
)
new_tool_defs_names = [t.tool_name for t in new_tool_defs]
assert MEMORY_QUERY_TOOL in new_tool_defs_names
assert BuiltinTool.code_interpreter not in new_tool_defs_names

View file

@ -1,109 +0,0 @@
# Testing Llama Stack Providers
The Llama Stack is designed as a collection of Lego blocks -- various APIs -- which are composable and can be used to quickly and reliably build an app. We need a testing setup which is relatively flexible to enable easy combinations of these providers.
We use `pytest` and all of its dynamism to enable the features needed. Specifically:
- We use `pytest_addoption` to add CLI options allowing you to override providers, models, etc.
- We use `pytest_generate_tests` to dynamically parametrize our tests. This allows us to support a default set of (providers, models, etc.) combinations but retain the flexibility to override them via the CLI if needed.
- We use `pytest_configure` to make sure we dynamically add appropriate marks based on the fixtures we make.
- We use `pytest_collection_modifyitems` to filter tests based on the test config (if specified).
## Pre-requisites
Your development environment should have been configured as per the instructions in the
[CONTRIBUTING.md](../../../CONTRIBUTING.md) file. In particular, make sure to install the test extra
dependencies. Below is the full configuration:
```bash
cd llama-stack
uv sync --extra dev --extra test
uv pip install -e .
source .venv/bin/activate
```
## Common options
All tests support a `--providers` option which can be a string of the form `api1=provider_fixture1,api2=provider_fixture2`. So, when testing safety (which need inference and safety APIs) you can use `--providers inference=together,safety=meta_reference` to use these fixtures in concert.
Depending on the API, there are custom options enabled. For example, `inference` tests allow for an `--inference-model` override, etc.
By default, we disable warnings and enable short tracebacks. You can override them using pytest's flags as appropriate.
Some providers need special API keys or other configuration options to work. You can check out the individual fixtures (located in `tests/<api>/fixtures.py`) for what these keys are. These can be specified using the `--env` CLI option. You can also have it be present in the environment (exporting in your shell) or put it in the `.env` file in the directory from which you run the test. For example, to use the Together fixture you can use `--env TOGETHER_API_KEY=<...>`
## Inference
We have the following orthogonal parametrizations (pytest "marks") for inference tests:
- providers: (meta_reference, together, fireworks, ollama)
- models: (llama_8b, llama_3b)
If you want to run a test with the llama_8b model with fireworks, you can use:
```bash
pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
-m "fireworks and llama_8b" \
--env FIREWORKS_API_KEY=<...>
```
You can make it more complex to run both llama_8b and llama_3b on Fireworks, but only llama_3b with Ollama:
```bash
pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
-m "fireworks or (ollama and llama_3b)" \
--env FIREWORKS_API_KEY=<...>
```
Finally, you can override the model completely by doing:
```bash
pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py \
-m fireworks \
--inference-model "meta-llama/Llama3.1-70B-Instruct" \
--env FIREWORKS_API_KEY=<...>
```
> [!TIP]
> If youre using `uv`, you can isolate test executions by prefixing all commands with `uv run pytest...`.
## Agents
The Agents API composes three other APIs underneath:
- Inference
- Safety
- Memory
Given that each of these has several fixtures each, the set of combinations is large. We provide a default set of combinations (see `tests/agents/conftest.py`) with easy to use "marks":
- `meta_reference` -- uses all the `meta_reference` fixtures for the dependent APIs
- `together` -- uses Together for inference, and `meta_reference` for the rest
- `ollama` -- uses Ollama for inference, and `meta_reference` for the rest
An example test with Together:
```bash
pytest -s -m together llama_stack/providers/tests/agents/test_agents.py \
--env TOGETHER_API_KEY=<...>
```
If you want to override the inference model or safety model used, you can use the `--inference-model` or `--safety-shield` CLI options as appropriate.
If you wanted to test a remotely hosted stack, you can use `-m remote` as follows:
```bash
pytest -s -m remote llama_stack/providers/tests/agents/test_agents.py \
--env REMOTE_STACK_URL=<...>
```
## Test Config
If you want to run a test suite with a custom set of tests and parametrizations, you can define a YAML test config under llama_stack/providers/tests/ folder and pass the filename through `--config` option as follows:
```
pytest llama_stack/providers/tests/ --config=ci_test_config.yaml
```
### Test config format
Currently, we support test config on inference, agents and memory api tests.
Example format of test config can be found in ci_test_config.yaml.
## Test Data
We encourage providers to use our test data for internal development testing, so to make it easier and consistent with the tests we provide. Each test case may define its own data format, and please refer to our test source code to get details on how these fields are used in the test.

View file

@ -1,101 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import tempfile
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from llama_stack.apis.benchmarks import BenchmarkInput
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput
from llama_stack.apis.scoring_functions import ScoringFnInput
from llama_stack.apis.shields import ShieldInput
from llama_stack.apis.tools import ToolGroupInput
from llama_stack.apis.vector_dbs import VectorDBInput
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_remote_stack_impls
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestStack(BaseModel):
impls: Dict[Api, Any]
run_config: StackRunConfig
async def construct_stack_for_test(
apis: List[Api],
providers: Dict[str, List[Provider]],
provider_data: Optional[Dict[str, Any]] = None,
models: Optional[List[ModelInput]] = None,
shields: Optional[List[ShieldInput]] = None,
vector_dbs: Optional[List[VectorDBInput]] = None,
datasets: Optional[List[DatasetInput]] = None,
scoring_fns: Optional[List[ScoringFnInput]] = None,
benchmarks: Optional[List[BenchmarkInput]] = None,
tool_groups: Optional[List[ToolGroupInput]] = None,
) -> TestStack:
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(
image_name="test-fixture",
apis=apis,
providers=providers,
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
models=models or [],
shields=shields or [],
vector_dbs=vector_dbs or [],
datasets=datasets or [],
scoring_fns=scoring_fns or [],
benchmarks=benchmarks or [],
tool_groups=tool_groups or [],
)
run_config = parse_and_maybe_upgrade_config(run_config)
try:
remote_config = remote_provider_config(run_config)
if not remote_config:
# TODO: add to provider registry by creating interesting mocks or fakes
impls = await construct_stack(run_config, get_provider_registry())
else:
# we don't register resources for a remote stack as part of the fixture setup
# because the stack is already "up". if a test needs to register resources, it
# can do so manually always.
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
test_stack = TestStack(impls=impls, run_config=run_config)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e
if provider_data:
set_request_provider_data({"X-LlamaStack-Provider-Data": json.dumps(provider_data)})
return test_stack
def remote_provider_config(
run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
remote_config = None
has_non_remote = False
for api_providers in run_config.providers.values():
for provider in api_providers:
if provider.provider_type == "test::remote":
remote_config = RemoteProviderConfig(**provider.config)
else:
has_non_remote = True
if remote_config:
assert not has_non_remote, "Remote stack cannot have non-remote providers"
return remote_config

View file

@ -1,101 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import List
import pytest
import requests
from pydantic import TypeAdapter
from llama_stack.apis.tools import (
DefaultRAGQueryGeneratorConfig,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
)
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.memory.vector_store import interleaved_content_as_str
class TestRAGToolEndpoints:
@pytest.fixture
def base_url(self) -> str:
return "http://localhost:8321/v1" # Adjust port if needed
@pytest.fixture
def sample_documents(self) -> List[RAGDocument]:
return [
RAGDocument(
document_id="doc1",
content="Python is a high-level programming language.",
metadata={"category": "programming", "difficulty": "beginner"},
),
RAGDocument(
document_id="doc2",
content="Machine learning is a subset of artificial intelligence.",
metadata={"category": "AI", "difficulty": "advanced"},
),
RAGDocument(
document_id="doc3",
content="Data structures are fundamental to computer science.",
metadata={"category": "computer science", "difficulty": "intermediate"},
),
]
@pytest.mark.asyncio
async def test_rag_workflow(self, base_url: str, sample_documents: List[RAGDocument]):
vector_db_payload = {
"vector_db_id": "test_vector_db",
"embedding_model": "all-MiniLM-L6-v2",
"embedding_dimension": 384,
}
response = requests.post(f"{base_url}/vector-dbs", json=vector_db_payload)
assert response.status_code == 200
vector_db = VectorDB(**response.json())
insert_payload = {
"documents": [json.loads(doc.model_dump_json()) for doc in sample_documents],
"vector_db_id": vector_db.identifier,
"chunk_size_in_tokens": 512,
}
response = requests.post(
f"{base_url}/tool-runtime/rag-tool/insert-documents",
json=insert_payload,
)
assert response.status_code == 200
query = "What is Python?"
query_config = RAGQueryConfig(
query_generator_config=DefaultRAGQueryGeneratorConfig(),
max_tokens_in_context=4096,
max_chunks=2,
)
query_payload = {
"content": query,
"query_config": json.loads(query_config.model_dump_json()),
"vector_db_ids": [vector_db.identifier],
}
response = requests.post(
f"{base_url}/tool-runtime/rag-tool/query-context",
json=query_payload,
)
assert response.status_code == 200
result = response.json()
result = TypeAdapter(RAGQueryResult).validate_python(result)
content_str = interleaved_content_as_str(result.content)
print(f"content: {content_str}")
assert len(content_str) > 0
assert "Python" in content_str
# Clean up: Delete the vector DB
response = requests.delete(f"{base_url}/vector-dbs/{vector_db.identifier}")
assert response.status_code == 200