forked from phoenix-oss/llama-stack-mirror
See https://github.com/meta-llama/llama-stack/issues/827 for the broader design.

Third part:

- We need to make the `tool_runtime.rag_tool.query_context()` and `tool_runtime.rag_tool.insert_documents()` methods work smoothly with complete type safety. To that end, we introduce a sub-resource path `tool-runtime/rag-tool/` and change the resolver accordingly.
- The PR updates the agents implementation to call these typed APIs directly for memory access, rather than going through the complex, untyped `invoke_tool` API. As expected, the code is much nicer and simpler.
- There are still a number of hacks in the server resolver implementation; we will live with some of them and fix others.

Note that we must also make sure the client SDKs can handle this sub-resource complexity. Stainless supports sub-resources, so this should be possible, but beware.

## Test Plan

Our RAG test is weak (it doesn't actually check the RAG output), but I verified that the implementation works. I will work on fixing the RAG test afterwards.

```bash
pytest -s -v tests/agents/test_agents.py -k "rag and together" --safety-shield=meta-llama/Llama-Guard-3-8B
```
128 lines
4.1 KiB
Python
128 lines
4.1 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import tempfile
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from llama_stack.apis.models import ModelInput, ModelType
|
|
from llama_stack.distribution.datatypes import Api, Provider
|
|
from llama_stack.providers.inline.agents.meta_reference import (
|
|
MetaReferenceAgentsImplConfig,
|
|
)
|
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|
|
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
|
|
|
|
|
def pick_inference_model(inference_model):
    """Return the agent's inference model from the ``inference_model`` fixture value.

    The ``inference_model`` fixture can correspond to multiple models when a
    safety model must run in addition to the normal agent inference model. In
    that (list) case the safety model is filtered out by looking for
    "Llama-Guard" in the name, and the first remaining model is returned. This
    name-based filtering is not entirely satisfactory.

    Args:
        inference_model: A single model id, or a list of model ids.

    Returns:
        The selected model id.

    Raises:
        AssertionError: If no usable model is found, i.e. ``None`` was passed,
            or the list is empty or contains only Llama-Guard models.
    """
    if isinstance(inference_model, list):
        # Default to None so an empty / all-safety-model list reaches the
        # assertion below instead of escaping as a bare StopIteration.
        inference_model = next(
            (m for m in inference_model if "Llama-Guard" not in m), None
        )

    assert inference_model is not None, "no non-Llama-Guard inference model found"
    return inference_model
|
|
|
|
|
|
@pytest.fixture(scope="session")
def agents_remote() -> ProviderFixture:
    """Session-scoped agents fixture backed by ``remote_stack_fixture()``.

    Returns:
        The ProviderFixture produced by ``remote_stack_fixture()``, unchanged.
    """
    remote_fixture = remote_stack_fixture()
    return remote_fixture
|
|
|
|
|
|
@pytest.fixture(scope="session")
def agents_meta_reference() -> ProviderFixture:
    """Session-scoped fixture for the inline meta-reference agents provider.

    Agent state is persisted through a SQLite KV store whose database lives in
    a freshly created temporary ``.db`` file.

    Returns:
        A ProviderFixture with a single ``inline::meta-reference`` provider.
    """
    # Only the file's path is needed, so close our handle immediately instead
    # of leaking the open descriptor for the whole session. (The tempfile docs
    # also note the name may not be reopenable on Windows while the file is
    # still open.) delete=False keeps the file on disk after closing.
    # NOTE(review): the file is never removed after the session.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as sqlite_file:
        db_path = sqlite_file.name

    return ProviderFixture(
        providers=[
            Provider(
                provider_id="meta-reference",
                provider_type="inline::meta-reference",
                config=MetaReferenceAgentsImplConfig(
                    # TODO: make this an in-memory store
                    persistence_store=SqliteKVStoreConfig(
                        db_path=db_path,
                    ),
                ).model_dump(),
            )
        ],
    )
|
|
|
|
|
|
# Provider-fixture suffixes available for the agents API; each entry names an
# ``agents_<suffix>`` fixture defined in this module.
AGENTS_FIXTURES = [
    "meta_reference",
    "remote",
]
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="session")
async def agents_stack(
    request,
    inference_model,
    safety_shield,
    tool_group_input_memory,
    tool_group_input_tavily_search,
):
    """Build the test stack used by the agents tests.

    ``request.param`` maps each API key ("inference", "safety", "vector_io",
    "agents", "tool_runtime") to a fixture suffix. For every key, the
    ``<key>_<suffix>`` fixture is looked up, and its providers and provider
    data are collected.

    An extra ``inline::sentence-transformers`` inference provider
    ("agents_memory_provider") is added to serve the ``all-MiniLM-L6-v2``
    embedding model. Each LLM in ``inference_model`` is registered with the
    inference provider whose config ``model`` matches it, falling back to the
    first inference provider.

    Returns:
        The stack built by ``construct_stack_for_test``.
    """
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["inference", "safety", "vector_io", "agents", "tool_runtime"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        # Copy the list: appending to `fixture.providers` directly would mutate
        # the provider fixture's own state. Those fixtures are presumably
        # session-scoped like the ones in this module, so every
        # parametrization would otherwise add yet another duplicate
        # "agents_memory_provider" to the same shared list.
        providers[key] = list(fixture.providers)
        if key == "inference":
            # Serves the embedding model registered below.
            providers[key].append(
                Provider(
                    provider_id="agents_memory_provider",
                    provider_type="inline::sentence-transformers",
                    config={},
                )
            )
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    inference_models = (
        inference_model if isinstance(inference_model, list) else [inference_model]
    )

    # NOTE: the meta-reference provider needs one provider per model, so look up
    # each model's provider_id from the providers' configs.
    model_to_provider_id = {
        provider.config["model"]: provider.provider_id
        for provider in providers["inference"]
        if "model" in provider.config
    }
    # Models without a dedicated provider go to the first inference provider;
    # the list is never empty here because of the append above.
    default_provider_id = providers["inference"][0].provider_id

    models = [
        ModelInput(
            model_id=model,
            model_type=ModelType.llm,
            provider_id=model_to_provider_id.get(model, default_provider_id),
        )
        for model in inference_models
    ]
    models.append(
        ModelInput(
            model_id="all-MiniLM-L6-v2",
            model_type=ModelType.embedding,
            provider_id="agents_memory_provider",
            metadata={"embedding_dimension": 384},
        )
    )

    test_stack = await construct_stack_for_test(
        [Api.agents, Api.inference, Api.safety, Api.vector_io, Api.tool_runtime],
        providers,
        provider_data,
        models=models,
        shields=[safety_shield] if safety_shield else [],
        tool_groups=[tool_group_input_memory, tool_group_input_tavily_search],
    )
    return test_stack
|