Kill "remote" providers and fix testing with a remote stack properly (#435)

# What does this PR do?

This PR kills the notion of "pure passthrough" remote providers. You
cannot specify a single provider you must specify a whole distribution
(stack) as remote.

This PR also significantly fixes / upgrades testing infrastructure so
you can now test against a remotely hosted stack server by just doing

```bash
pytest -s -v -m remote  test_agents.py \
  --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \
  --env REMOTE_STACK_URL=http://localhost:5001
```

Also fixed `test_agents_persistence.py` (which was broken) and killed
some deprecated testing functions.

## Test Plan

All the tests.
This commit is contained in:
Ashwin Bharambe 2024-11-12 21:51:29 -08:00 committed by GitHub
parent 59a65e34d3
commit 12947ac19e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 406 additions and 519 deletions

View file

@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
async def get_client_impl(
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
):
client_class = create_api_client_class(protocol, additional_protocol)
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
client_class = create_api_client_class(protocol)
impl = client_class(config.url)
await impl.initialize()
return impl
def create_api_client_class(protocol, additional_protocol) -> Type:
def create_api_client_class(protocol) -> Type:
if protocol in _CLIENT_CLASSES:
return _CLIENT_CLASSES[protocol]
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
class APIClient:
def __init__(self, base_url: str):
print(f"({protocol.__name__}) Connecting to {base_url}")
@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
self.routes = {}
# Store routes for this protocol
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
sig = inspect.signature(method)
self.routes[name] = (method.__webmethod__, sig)
async def initialize(self):
pass
@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
return ret
# Add protocol methods to the wrapper
for p in protocols:
for name, method in inspect.getmembers(p):
if hasattr(method, "__webmethod__"):
for name, method in inspect.getmembers(protocol):
if hasattr(method, "__webmethod__"):
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
async def method_impl(self, *args, method_name=name, **kwargs):
return await self.__acall__(method_name, *args, **kwargs)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
method_impl.__name__ = name
method_impl.__qualname__ = f"APIClient.{name}"
method_impl.__signature__ = inspect.signature(method)
setattr(APIClient, name, method_impl)
# Name the class after the protocol
APIClient.__name__ = f"{protocol.__name__}Client"

View file

@ -9,7 +9,7 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]:
@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
for api in providable_apis():
name = api.name.lower()
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {
"remote": remote_provider_spec(api),
**{a.provider_type: a for a in module.available_providers()},
}
ret[api] = {a.provider_type: a for a in module.available_providers()}
return ret

View file

@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -59,12 +60,16 @@ def api_protocol_map() -> Dict[Api, Any]:
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
Api.safety: (ShieldsProtocolPrivate, Shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
ScoringFunctionsProtocolPrivate,
ScoringFunctions,
Api.scoring_functions,
),
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
}
@ -73,10 +78,13 @@ class ProviderWithSpec(Provider):
spec: ProviderSpec
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
run_config: StackRunConfig,
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
provider_registry: ProviderRegistry,
dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
"""
@ -273,17 +281,8 @@ async def instantiate_provider(
config_type = instantiate_class_type(provider_spec.config_class)
config = config_type(**provider.config)
if provider_spec.adapter:
method = "get_adapter_impl"
args = [config, deps]
else:
method = "get_client_impl"
protocol = protocols[provider_spec.api]
if provider_spec.api in additional_protocols:
_, additional_protocol = additional_protocols[provider_spec.api]
else:
additional_protocol = None
args = [protocol, additional_protocol, config, deps]
method = "get_adapter_impl"
args = [config, deps]
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
@ -313,7 +312,7 @@ async def instantiate_provider(
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _ = additional_protocols[provider_spec.api]
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
return impl
@ -359,3 +358,29 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
raise ValueError(
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
)
async def resolve_remote_stack_impls(
config: RemoteProviderConfig,
apis: List[str],
) -> Dict[Api, Any]:
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
config,
{},
)
return impls

View file

@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
api = get_impl_api(p)
if obj.provider_id == "remote":
# TODO: this is broken right now because we use the generic
# { identifier, provider_id, provider_resource_id } tuple here
# but the APIs expect things like ModelInput, ShieldInput, etc.
# if this is just a passthrough, we want to let the remote
# end actually do the registration with the correct provider
obj = obj.model_copy(deep=True)
obj.provider_id = ""
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
await p.register_shield(obj)
return await p.register_shield(obj)
elif api == Api.memory:
await p.register_memory_bank(obj)
return await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(obj)
return await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(obj)
return await p.register_scoring_function(obj)
elif api == Api.eval:
await p.register_eval_task(obj)
return await p.register_eval_task(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -82,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
if cls is None:
obj.provider_id = provider_id
else:
if provider_id == "remote":
# if this is just a passthrough, we got the *WithProvider object
# so we should just override the provider in-place
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable):
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.memory:
p.memory_bank_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self

View file

@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
await impl.shutdown()
def create_dynamic_passthrough(
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
):
async def endpoint(request: Request):
return await passthrough(request, downstream_url, downstream_headers)
return endpoint
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
@ -305,28 +296,19 @@ def main(
endpoints = all_endpoints[api]
impl = impls[api]
if is_passthrough(impl.__provider_spec__):
for endpoint in endpoints:
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
getattr(app, endpoint.method)(endpoint.route)(
create_dynamic_passthrough(url)
)
else:
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(
f"Could not find method {endpoint.name} on {impl}!!"
)
for endpoint in endpoints:
if not hasattr(impl, endpoint.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
impl_method = getattr(impl, endpoint.name)
impl_method = getattr(impl, endpoint.name)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
)
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
create_dynamic_typed_route(
impl_method,
endpoint.method,
)
)
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
for endpoint in endpoints:

View file

@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
@ -58,29 +58,23 @@ class LlamaStack(
pass
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
RESOURCES = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
resources = [
("models", Api.models, "register_model", "list_models"),
("shields", Api.shields, "register_shield", "list_shields"),
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
("datasets", Api.datasets, "register_dataset", "list_datasets"),
(
"scoring_fns",
Api.scoring_functions,
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
]
for rsrc, api, register_method, list_method in resources:
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
if api not in impls:
continue
@ -96,4 +90,18 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
)
print("")
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(), dist_registry
)
await register_resources(run_config, impls)
return impls

View file

@ -99,6 +99,7 @@ class RoutingTable(Protocol):
def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: Optional[AdapterSpec] = Field(
default=None,
adapter: AdapterSpec = Field(
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here. If not specified, it indicates the remote
as being "Llama Stack compatible"
API responses, specify the adapter here.
""",
)
@ -186,38 +185,21 @@ as being "Llama Stack compatible"
@property
def module(self) -> str:
if self.adapter:
return self.adapter.module
return "llama_stack.distribution.client"
return self.adapter.module
@property
def pip_packages(self) -> List[str]:
if self.adapter:
return self.adapter.pip_packages
return []
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> Optional[str]:
if self.adapter:
return self.adapter.provider_data_validator
return None
return self.adapter.provider_data_validator
def is_passthrough(spec: ProviderSpec) -> bool:
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
# Can avoid this by using Pydantic computed_field
def remote_provider_spec(
api: Api, adapter: Optional[AdapterSpec] = None
) -> RemoteProviderSpec:
config_class = (
adapter.config_class
if adapter and adapter.config_class
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
)
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
adapter=adapter,
)

View file

@ -234,7 +234,7 @@ class LlamaGuardShield:
# TODO: llama-stack inference protocol has issues with non-streaming inference code
content = ""
async for chunk in await self.inference_api.chat_completion(
model=self.model,
model_id=self.model,
messages=[shield_input_message],
stream=True,
):

View file

@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]:
adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
),
),
remote_provider_spec(

View file

@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
print(f"model={model}")
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,

View file

@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
from .fixtures import AGENTS_FIXTURES
@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
id="together",
marks=pytest.mark.together,
),
pytest.param(
{
"inference": "fireworks",
"safety": "llama_guard",
"memory": "faiss",
"agents": "meta_reference",
},
id="fireworks",
marks=pytest.mark.fireworks,
),
pytest.param(
{
"inference": "remote",
@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
def pytest_configure(config):
for mark in ["meta_reference", "ollama", "together", "remote"]:
for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
config.addinivalue_line(
"markers",
f"{mark}: marks tests as {mark} specific",
@ -75,28 +85,30 @@ def pytest_addoption(parser):
help="Specify the inference model to use for testing",
)
parser.addoption(
"--safety-model",
"--safety-shield",
action="store",
default="Llama-Guard-3-8B",
help="Specify the safety model to use for testing",
help="Specify the safety shield to use for testing",
)
def pytest_generate_tests(metafunc):
safety_model = metafunc.config.getoption("--safety-model")
if "safety_model" in metafunc.fixturenames:
shield_id = metafunc.config.getoption("--safety-shield")
if "safety_shield" in metafunc.fixturenames:
metafunc.parametrize(
"safety_model",
[pytest.param(safety_model, id="")],
"safety_shield",
[pytest.param(shield_id, id="")],
indirect=True,
)
if "inference_model" in metafunc.fixturenames:
inference_model = metafunc.config.getoption("--inference-model")
models = list(set({inference_model, safety_model}))
models = set({inference_model})
if safety_model := safety_model_from_shield(shield_id):
models.add(safety_model)
metafunc.parametrize(
"inference_model",
[pytest.param(models, id="")],
[pytest.param(list(models), id="")],
indirect=True,
)
if "agents_stack" in metafunc.fixturenames:

View file

@ -16,10 +16,9 @@ from llama_stack.providers.inline.agents.meta_reference import (
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..safety.fixtures import get_shield_to_register
def pick_inference_model(inference_model):
@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def agents_stack(request, inference_model, safety_model):
async def agents_stack(request, inference_model, safety_shield):
fixture_dict = request.param
providers = {}
@ -71,13 +70,10 @@ async def agents_stack(request, inference_model, safety_model):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
shield_input = get_shield_to_register(
providers["safety"][0].provider_type, safety_model
)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
impls = await resolve_impls_for_test_v2(
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory],
providers,
provider_data,
@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
)
for model in inference_models
],
shields=[shield_input],
shields=[safety_shield],
)
return impls[Api.agents], impls[Api.memory]
return test_stack

View file

@ -1,148 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.tests.resolver import resolve_impls_for_test
from llama_stack.providers.datatypes import * # noqa: F403
from dotenv import load_dotenv
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
# How to run this test:
#
# 1. Ensure you have a conda environment with the right dependencies installed.
# This includes `pytest` and `pytest-asyncio`.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
# --tb=short --disable-warnings
# ```
load_dotenv()
@pytest_asyncio.fixture(scope="session")
async def agents_settings():
impls = await resolve_impls_for_test(
Api.agents, deps=[Api.inference, Api.memory, Api.safety]
)
return {
"impl": impls[Api.agents],
"memory_impl": impls[Api.memory],
"common_params": {
"model": "Llama3.1-8B-Instruct",
"instructions": "You are a helpful assistant.",
},
}
@pytest.fixture
def sample_messages():
return [
UserMessage(content="What's the weather like today?"),
]
@pytest.mark.asyncio
async def test_delete_agents_and_sessions(agents_settings, sample_messages):
agents_impl = agents_settings["impl"]
# First, create an agent
agent_config = AgentConfig(
model=agents_settings["common_params"]["model"],
instructions=agents_settings["common_params"]["instructions"],
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
max_infer_iters=5,
)
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
persistence_store = await kvstore_impl(agents_settings["persistence"])
await agents_impl.delete_agents_session(agent_id, session_id)
session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
await agents_impl.delete_agents(agent_id)
agent_response = await persistence_store.get(f"agent:{agent_id}")
assert session_response is None
assert agent_response is None
async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
agents_impl = agents_settings["impl"]
# First, create an agent
agent_config = AgentConfig(
model=agents_settings["common_params"]["model"],
instructions=agents_settings["common_params"]["instructions"],
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
max_infer_iters=5,
)
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
# Create and execute a turn
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=sample_messages,
stream=True,
)
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
final_event = turn_response[-1].event.payload
turn_id = final_event.turn.turn_id
persistence_store = await kvstore_impl(SqliteKVStoreConfig())
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
assert isinstance(response, Turn)
assert response == final_event.turn
assert turn == final_event.turn
steps = final_event.turn.steps
step_id = steps[0].step_id
step_response = await agents_impl.get_agents_step(
agent_id, session_id, turn_id, step_id
)
assert isinstance(step_response.step, Step)
assert step_response.step == steps[0]

View file

@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403
# -m "meta_reference"
from .fixtures import pick_inference_model
from .utils import create_agent_session
@pytest.fixture
@ -67,31 +68,19 @@ def query_attachment_messages():
]
async def create_agent_session(agents_impl, agent_config):
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
return agent_id, session_id
class TestAgents:
@pytest.mark.asyncio
async def test_agent_turns_with_safety(
self, safety_model, agents_stack, common_params
self, safety_shield, agents_stack, common_params
):
agents_impl, _ = agents_stack
agents_impl = agents_stack.impls[Api.agents]
agent_id, session_id = await create_agent_session(
agents_impl,
AgentConfig(
**{
**common_params,
"input_shields": [safety_model],
"output_shields": [safety_model],
"input_shields": [safety_shield.shield_id],
"output_shields": [safety_shield.shield_id],
}
),
)
@ -127,7 +116,7 @@ class TestAgents:
async def test_create_agent_turn(
self, agents_stack, sample_messages, common_params
):
agents_impl, _ = agents_stack
agents_impl = agents_stack.impls[Api.agents]
agent_id, session_id = await create_agent_session(
agents_impl, AgentConfig(**common_params)
@ -158,7 +147,7 @@ class TestAgents:
query_attachment_messages,
common_params,
):
agents_impl, _ = agents_stack
agents_impl = agents_stack.impls[Api.agents]
urls = [
"memory_optimizations.rst",
"chat.rst",
@ -226,7 +215,7 @@ class TestAgents:
async def test_create_agent_turn_with_brave_search(
self, agents_stack, search_query_messages, common_params
):
agents_impl, _ = agents_stack
agents_impl = agents_stack.impls[Api.agents]
if "BRAVE_SEARCH_API_KEY" not in os.environ:
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")

View file

@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .fixtures import pick_inference_model
from .utils import create_agent_session
@pytest.fixture
def sample_messages():
return [
UserMessage(content="What's the weather like today?"),
]
@pytest.fixture
def common_params(inference_model):
inference_model = pick_inference_model(inference_model)
return dict(
model=inference_model,
instructions="You are a helpful assistant.",
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
input_shields=[],
output_shields=[],
tools=[],
max_infer_iters=5,
)
class TestAgentPersistence:
@pytest.mark.asyncio
async def test_delete_agents_and_sessions(self, agents_stack, common_params):
agents_impl = agents_stack.impls[Api.agents]
agent_id, session_id = await create_agent_session(
agents_impl,
AgentConfig(
**{
**common_params,
"input_shields": [],
"output_shields": [],
}
),
)
run_config = agents_stack.run_config
provider_config = run_config.providers["agents"][0].config
persistence_store = await kvstore_impl(
SqliteKVStoreConfig(**provider_config["persistence_store"])
)
await agents_impl.delete_agents_session(agent_id, session_id)
session_response = await persistence_store.get(
f"session:{agent_id}:{session_id}"
)
await agents_impl.delete_agents(agent_id)
agent_response = await persistence_store.get(f"agent:{agent_id}")
assert session_response is None
assert agent_response is None
@pytest.mark.asyncio
async def test_get_agent_turns_and_steps(
self, agents_stack, sample_messages, common_params
):
agents_impl = agents_stack.impls[Api.agents]
agent_id, session_id = await create_agent_session(
agents_impl,
AgentConfig(
**{
**common_params,
"input_shields": [],
"output_shields": [],
}
),
)
# Create and execute a turn
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=sample_messages,
stream=True,
)
turn_response = [
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
]
final_event = turn_response[-1].event.payload
turn_id = final_event.turn.turn_id
provider_config = agents_stack.run_config.providers["agents"][0].config
persistence_store = await kvstore_impl(
SqliteKVStoreConfig(**provider_config["persistence_store"])
)
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
assert isinstance(response, Turn)
assert response == final_event.turn
assert turn == final_event.turn.model_dump_json()
steps = final_event.turn.steps
step_id = steps[0].step_id
step_response = await agents_impl.get_agents_step(
agent_id, session_id, turn_id, step_id
)
assert step_response.step == steps[0]

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
async def create_agent_session(agents_impl, agent_config):
create_response = await agents_impl.create_agent(agent_config)
agent_id = create_response.agent_id
# Create a session
session_create_response = await agents_impl.create_agent_session(
agent_id, "Test Session"
)
session_id = session_create_response.session_id
return agent_id, session_id

View file

@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="remote",
provider_type="remote",
provider_id="test::remote",
provider_type="test::remote",
config=config.model_dump(),
)
],

View file

@ -9,7 +9,7 @@ import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@ -52,10 +52,10 @@ async def datasetio_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
impls = await resolve_impls_for_test_v2(
test_stack = await construct_stack_for_test(
[Api.datasetio],
{"datasetio": fixture.providers},
fixture.provider_data,
)
return impls[Api.datasetio], impls[Api.datasets]
return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets]

View file

@ -9,7 +9,7 @@ import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@ -46,10 +46,10 @@ async def eval_stack(request):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
impls = await resolve_impls_for_test_v2(
test_stack = await construct_stack_for_test(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
providers,
provider_data,
)
return impls
return test_stack.impls

View file

@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@ -182,15 +182,11 @@ INFERENCE_FIXTURES = [
async def inference_stack(request, inference_model):
fixture_name = request.param
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
impls = await resolve_impls_for_test_v2(
test_stack = await construct_stack_for_test(
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[
ModelInput(
model_id=inference_model,
)
],
models=[ModelInput(model_id=inference_model)],
)
return (impls[Api.inference], impls[Api.models])
return test_stack.impls[Api.inference], test_stack.impls[Api.models]

View file

@ -147,9 +147,9 @@ class TestInference:
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion(
model_id=inference_model,
content=user_input,
stream=False,
model=inference_model,
sampling_params=SamplingParams(
max_tokens=50,
),

View file

@ -55,7 +55,7 @@ class TestVisionModelInference:
)
response = await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=[
UserMessage(content="You are a helpful assistant."),
UserMessage(content=[image, "Describe this image in two sentences."]),
@ -102,7 +102,7 @@ class TestVisionModelInference:
response = [
r
async for r in await inference_impl.chat_completion(
model=inference_model,
model_id=inference_model,
messages=[
UserMessage(content="You are a helpful assistant."),
UserMessage(

View file

@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConf
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.tests.resolver import construct_stack_for_test
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@ -101,10 +101,10 @@ async def memory_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"memory_{fixture_name}")
impls = await resolve_impls_for_test_v2(
test_stack = await construct_stack_for_test(
[Api.memory],
{"memory": fixture.providers},
fixture.provider_data,
)
return impls[Api.memory], impls[Api.memory_banks]
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]

View file

@ -5,33 +5,36 @@
# the root directory of this source tree.
import json
import os
import tempfile
from datetime import datetime
from typing import Any, Dict, List, Optional
import yaml
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_remote_stack_impls
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
async def resolve_impls_for_test_v2(
class TestStack(BaseModel):
impls: Dict[Api, Any]
run_config: StackRunConfig
async def construct_stack_for_test(
apis: List[Api],
providers: Dict[str, List[Provider]],
provider_data: Optional[Dict[str, Any]] = None,
models: Optional[List[Model]] = None,
shields: Optional[List[Shield]] = None,
memory_banks: Optional[List[MemoryBank]] = None,
datasets: Optional[List[Dataset]] = None,
scoring_fns: Optional[List[ScoringFn]] = None,
eval_tasks: Optional[List[EvalTask]] = None,
):
models: Optional[List[ModelInput]] = None,
shields: Optional[List[ShieldInput]] = None,
memory_banks: Optional[List[MemoryBankInput]] = None,
datasets: Optional[List[DatasetInput]] = None,
scoring_fns: Optional[List[ScoringFnInput]] = None,
eval_tasks: Optional[List[EvalTaskInput]] = None,
) -> TestStack:
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(
built_at=datetime.now(),
@ -48,7 +51,18 @@ async def resolve_impls_for_test_v2(
)
run_config = parse_and_maybe_upgrade_config(run_config)
try:
impls = await construct_stack(run_config)
remote_config = remote_provider_config(run_config)
if not remote_config:
# TODO: add to provider registry by creating interesting mocks or fakes
impls = await construct_stack(run_config, get_provider_registry())
else:
# we don't register resources for a remote stack as part of the fixture setup
# because the stack is already "up". if a test needs to register resources, it
# can do so manually always.
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
test_stack = TestStack(impls=impls, run_config=run_config)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e
@ -58,91 +72,22 @@ async def resolve_impls_for_test_v2(
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)
return impls
return test_stack
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
if "PROVIDER_CONFIG" not in os.environ:
raise ValueError(
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
)
def remote_provider_config(
run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
remote_config = None
has_non_remote = False
for api_providers in run_config.providers.values():
for provider in api_providers:
if provider.provider_type == "test::remote":
remote_config = RemoteProviderConfig(**provider.config)
else:
has_non_remote = True
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
config_dict = yaml.safe_load(f)
if remote_config:
assert not has_non_remote, "Remote stack cannot have non-remote providers"
providers = read_providers(api, config_dict)
chosen = choose_providers(providers, api, deps)
run_config = dict(
built_at=datetime.now(),
image_name="test-fixture",
apis=[api] + (deps or []),
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
try:
impls = await resolve_impls(run_config, get_provider_registry())
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e
if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id
provider_data = config_dict["provider_data"].get(provider_id, {})
if provider_data:
set_request_provider_data(
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
)
return impls
def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
if "providers" not in config_dict:
raise ValueError("Config file should contain a `providers` key")
providers = config_dict["providers"]
if isinstance(providers, dict):
return providers
elif isinstance(providers, list):
return {
api.value: providers,
}
else:
raise ValueError(
"Config file should contain a list of providers or dict(api to providers)"
)
def choose_providers(
providers: Dict[str, Any], api: Api, deps: List[Api] = None
) -> Dict[str, Provider]:
chosen = {}
if api.value not in providers:
raise ValueError(f"No providers found for `{api}`?")
chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")]
for dep in deps or []:
if dep.value not in providers:
raise ValueError(f"No providers specified for `{dep}` in config?")
chosen[dep.value] = [Provider(**x) for x in providers[dep.value]]
return chosen
def pick_provider(api: Api, providers: List[Any], key: str) -> Provider:
providers_by_id = {x["provider_id"]: x for x in providers}
if len(providers_by_id) == 0:
raise ValueError(f"No providers found for `{api}` in config file")
if key in os.environ:
provider_id = os.environ[key]
if provider_id not in providers_by_id:
raise ValueError(f"Provider ID {provider_id} not found in config file")
provider = providers_by_id[provider_id]
else:
provider = list(providers_by_id.values())[0]
provider_id = provider["provider_id"]
print(f"No provider ID specified, picking first `{provider_id}`")
return Provider(**provider)
return remote_config

View file

@ -66,14 +66,14 @@ def pytest_configure(config):
def pytest_addoption(parser):
parser.addoption(
"--safety-model",
"--safety-shield",
action="store",
default=None,
help="Specify the safety model to use for testing",
help="Specify the safety shield to use for testing",
)
SAFETY_MODEL_PARAMS = [
SAFETY_SHIELD_PARAMS = [
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
]
@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
# But a user can also pass in a custom combination via the CLI by doing
# `--providers inference=together,safety=meta_reference`
if "safety_model" in metafunc.fixturenames:
model = metafunc.config.getoption("--safety-model")
if model:
params = [pytest.param(model, id="")]
if "safety_shield" in metafunc.fixturenames:
shield_id = metafunc.config.getoption("--safety-shield")
if shield_id:
params = [pytest.param(shield_id, id="")]
else:
params = SAFETY_MODEL_PARAMS
for fixture in ["inference_model", "safety_model"]:
params = SAFETY_SHIELD_PARAMS
for fixture in ["inference_model", "safety_shield"]:
metafunc.parametrize(
fixture,
params,

View file

@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture:
return remote_stack_fixture()
def safety_model_from_shield(shield_id):
if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
return None
return shield_id
@pytest.fixture(scope="session")
def safety_model(request):
def safety_shield(request):
if hasattr(request, "param"):
return request.param
return request.config.getoption("--safety-model", None)
shield_id = request.param
else:
shield_id = request.config.getoption("--safety-shield", None)
if shield_id == "bedrock":
shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
else:
params = {}
return ShieldInput(
shield_id=shield_id,
params=params,
)
@pytest.fixture(scope="session")
def safety_llama_guard(safety_model) -> ProviderFixture:
def safety_llama_guard() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="inline::llama-guard",
provider_id="llama-guard",
provider_type="inline::llama-guard",
config=LlamaGuardConfig().model_dump(),
)
@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="inline::prompt-guard",
provider_id="prompt-guard",
provider_type="inline::prompt-guard",
config=PromptGuardConfig().model_dump(),
)
@ -80,50 +99,25 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
@pytest_asyncio.fixture(scope="session")
async def safety_stack(inference_model, safety_model, request):
async def safety_stack(inference_model, safety_shield, request):
# We need an inference + safety fixture to test safety
fixture_dict = request.param
inference_fixture = request.getfixturevalue(
f"inference_{fixture_dict['inference']}"
)
safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
providers = {
"inference": inference_fixture.providers,
"safety": safety_fixture.providers,
}
providers = {}
provider_data = {}
if inference_fixture.provider_data:
provider_data.update(inference_fixture.provider_data)
if safety_fixture.provider_data:
provider_data.update(safety_fixture.provider_data)
for key in ["inference", "safety"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
shield_provider_type = safety_fixture.providers[0].provider_type
shield_input = get_shield_to_register(shield_provider_type, safety_model)
print(f"inference_model: {inference_model}")
print(f"shield_input = {shield_input}")
impls = await resolve_impls_for_test_v2(
test_stack = await construct_stack_for_test(
[Api.safety, Api.shields, Api.inference],
providers,
provider_data,
models=[ModelInput(model_id=inference_model)],
shields=[shield_input],
shields=[safety_shield],
)
shield = await impls[Api.shields].get_shield(shield_input.shield_id)
return impls[Api.safety], impls[Api.shields], shield
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
if provider_type == "remote::bedrock":
identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
else:
params = {}
identifier = safety_model
return ShieldInput(
shield_id=identifier,
params=params,
)
shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield

View file

@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403
class TestSafety:
@pytest.mark.asyncio
async def test_new_shield(self, safety_stack):
_, shields_impl, shield = safety_stack
assert shield is not None
assert shield.provider_resource_id == shield.identifier
assert shield.provider_id is not None
@pytest.mark.asyncio
async def test_shield_list(self, safety_stack):
_, shields_impl, _ = safety_stack

View file

@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model):
if fixture.provider_data:
provider_data.update(fixture.provider_data)
impls = await resolve_impls_for_test_v2(
test_stack = await construct_stack_for_test(
[Api.scoring, Api.datasetio, Api.inference],
providers,
provider_data,
@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model):
],
)
return impls
return test_stack.impls