Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
Kill "remote" providers and fix testing with a remote stack properly (#435)
# What does this PR do?

This PR kills the notion of "pure passthrough" remote providers: you can no longer mark a single provider as remote; you must declare a whole distribution (stack) as remote. It also significantly fixes and upgrades the testing infrastructure, so you can now test against a remotely hosted stack server by simply running:

```bash
pytest -s -v -m remote test_agents.py \
  --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \
  --env REMOTE_STACK_URL=http://localhost:5001
```

Also fixed `test_agents_persistence.py` (which was broken) and killed some deprecated testing functions.

## Test Plan

All the tests.
Parent: 59a65e34d3
Commit: 12947ac19e
28 changed files with 406 additions and 519 deletions
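At a glance: instead of registering a per-API "remote" passthrough provider, a client now resolves every API of an already-running stack in one step. A hedged sketch of the resulting call path, using only names that appear in the diffs below (the surrounding glue is illustrative):

```python
# Illustrative only. resolve_remote_stack_impls (added to
# llama_stack/distribution/resolver.py in this PR) builds one dynamically
# generated client per requested API against a single RemoteProviderConfig.
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
agents_impl = impls[Api.agents]  # a generated "{Protocol}Client" for the remote stack
```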
```diff
@@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
 
 _CLIENT_CLASSES = {}
 
 
-async def get_client_impl(
-    protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
-):
-    client_class = create_api_client_class(protocol, additional_protocol)
+async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
+    client_class = create_api_client_class(protocol)
     impl = client_class(config.url)
     await impl.initialize()
     return impl
 
 
-def create_api_client_class(protocol, additional_protocol) -> Type:
+def create_api_client_class(protocol) -> Type:
     if protocol in _CLIENT_CLASSES:
         return _CLIENT_CLASSES[protocol]
 
-    protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
-
     class APIClient:
         def __init__(self, base_url: str):
             print(f"({protocol.__name__}) Connecting to {base_url}")
@@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
             self.routes = {}
 
             # Store routes for this protocol
-            for p in protocols:
-                for name, method in inspect.getmembers(p):
-                    if hasattr(method, "__webmethod__"):
-                        sig = inspect.signature(method)
-                        self.routes[name] = (method.__webmethod__, sig)
+            for name, method in inspect.getmembers(protocol):
+                if hasattr(method, "__webmethod__"):
+                    sig = inspect.signature(method)
+                    self.routes[name] = (method.__webmethod__, sig)
 
         async def initialize(self):
             pass
@@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
             return ret
 
     # Add protocol methods to the wrapper
-    for p in protocols:
-        for name, method in inspect.getmembers(p):
-            if hasattr(method, "__webmethod__"):
-
-                async def method_impl(self, *args, method_name=name, **kwargs):
-                    return await self.__acall__(method_name, *args, **kwargs)
-
-                method_impl.__name__ = name
-                method_impl.__qualname__ = f"APIClient.{name}"
-                method_impl.__signature__ = inspect.signature(method)
-                setattr(APIClient, name, method_impl)
+    for name, method in inspect.getmembers(protocol):
+        if hasattr(method, "__webmethod__"):
+
+            async def method_impl(self, *args, method_name=name, **kwargs):
+                return await self.__acall__(method_name, *args, **kwargs)
+
+            method_impl.__name__ = name
+            method_impl.__qualname__ = f"APIClient.{name}"
+            method_impl.__signature__ = inspect.signature(method)
+            setattr(APIClient, name, method_impl)
 
     # Name the class after the protocol
     APIClient.__name__ = f"{protocol.__name__}Client"
```
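The client module above builds API clients entirely by reflection: any protocol method carrying a `__webmethod__` attribute becomes a forwarding client method. A self-contained sketch of that pattern — `webmethod`, `Inference`, and `Client.__acall__` are stand-ins invented for illustration; only the `getmembers`/`__webmethod__`/`setattr` mechanics mirror the diff:

```python
import asyncio
import inspect


def webmethod(fn):
    fn.__webmethod__ = True  # marker attribute, as checked in the diff
    return fn


class Inference:  # a toy "protocol"
    @webmethod
    async def chat_completion(self, model_id: str, messages: list): ...


class Client:
    async def __acall__(self, method_name, *args, **kwargs):
        # the real client resolves a route from self.routes and issues an HTTP call
        return f"would call {method_name} with {kwargs}"


# attach one forwarding method per tagged protocol method
for name, method in inspect.getmembers(Inference):
    if hasattr(method, "__webmethod__"):

        async def method_impl(self, *args, method_name=name, **kwargs):
            return await self.__acall__(method_name, *args, **kwargs)

        method_impl.__name__ = name
        setattr(Client, name, method_impl)

print(asyncio.run(Client().chat_completion(model_id="m", messages=[])))
```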
```diff
@@ -9,7 +9,7 @@ from typing import Dict, List
 
 from pydantic import BaseModel
 
-from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, ProviderSpec
 
 
 def stack_apis() -> List[Api]:
@@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
     for api in providable_apis():
         name = api.name.lower()
         module = importlib.import_module(f"llama_stack.providers.registry.{name}")
-        ret[api] = {
-            "remote": remote_provider_spec(api),
-            **{a.provider_type: a for a in module.available_providers()},
-        }
+        ret[api] = {a.provider_type: a for a in module.available_providers()}
 
     return ret
```
```diff
@@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.telemetry import Telemetry
+from llama_stack.distribution.client import get_client_impl
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -59,12 +60,16 @@ def api_protocol_map() -> Dict[Api, Any]:
 
 def additional_protocols_map() -> Dict[Api, Any]:
     return {
-        Api.inference: (ModelsProtocolPrivate, Models),
-        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
-        Api.safety: (ShieldsProtocolPrivate, Shields),
-        Api.datasetio: (DatasetsProtocolPrivate, Datasets),
-        Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
-        Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
+        Api.inference: (ModelsProtocolPrivate, Models, Api.models),
+        Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
+        Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
+        Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
+        Api.scoring: (
+            ScoringFunctionsProtocolPrivate,
+            ScoringFunctions,
+            Api.scoring_functions,
+        ),
+        Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
     }
```
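Note the shape change in `additional_protocols_map`: each entry is now a triple whose third element is the registry API enum. That third element is what lets the remote-stack resolver expose, say, the `models` registry client alongside `inference`. Sketched consumption (mirroring `resolve_remote_stack_impls` below):

```python
# Unpacking the new triple for the inference API.
_, models_protocol, models_api = additional_protocols_map()[Api.inference]
# models_protocol is the public Models protocol and models_api == Api.models,
# so a remote stack serves both impls[Api.inference] and impls[Api.models].
```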
```diff
@@ -73,10 +78,13 @@ class ProviderWithSpec(Provider):
     spec: ProviderSpec
 
 
+ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
+
+
 # TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls(
     run_config: StackRunConfig,
-    provider_registry: Dict[Api, Dict[str, ProviderSpec]],
+    provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
 ) -> Dict[Api, Any]:
     """
@@ -273,17 +281,8 @@ async def instantiate_provider(
         config_type = instantiate_class_type(provider_spec.config_class)
         config = config_type(**provider.config)
 
-        if provider_spec.adapter:
-            method = "get_adapter_impl"
-            args = [config, deps]
-        else:
-            method = "get_client_impl"
-            protocol = protocols[provider_spec.api]
-            if provider_spec.api in additional_protocols:
-                _, additional_protocol = additional_protocols[provider_spec.api]
-            else:
-                additional_protocol = None
-            args = [protocol, additional_protocol, config, deps]
+        method = "get_adapter_impl"
+        args = [config, deps]
 
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
@@ -313,7 +312,7 @@ async def instantiate_provider(
         not isinstance(provider_spec, AutoRoutedProviderSpec)
         and provider_spec.api in additional_protocols
     ):
-        additional_api, _ = additional_protocols[provider_spec.api]
+        additional_api, _, _ = additional_protocols[provider_spec.api]
         check_protocol_compliance(impl, additional_api)
 
     return impl
```
```diff
@@ -359,3 +358,29 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
         raise ValueError(
             f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
         )
+
+
+async def resolve_remote_stack_impls(
+    config: RemoteProviderConfig,
+    apis: List[str],
+) -> Dict[Api, Any]:
+    protocols = api_protocol_map()
+    additional_protocols = additional_protocols_map()
+
+    impls = {}
+    for api_str in apis:
+        api = Api(api_str)
+        impls[api] = await get_client_impl(
+            protocols[api],
+            config,
+            {},
+        )
+        if api in additional_protocols:
+            _, additional_protocol, additional_api = additional_protocols[api]
+            impls[additional_api] = await get_client_impl(
+                additional_protocol,
+                config,
+                {},
+            )
+
+    return impls
```
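With `resolve_remote_stack_impls` in place, pointing a local process at a running stack is a single call. A hedged usage sketch, assuming the stack at `http://localhost:5001` serves the inference API and that `RemoteProviderConfig` can be built from a bare URL (its exact fields are not shown in this diff):

```python
import asyncio

from llama_stack.distribution.resolver import resolve_remote_stack_impls
from llama_stack.providers.datatypes import Api, RemoteProviderConfig


async def main():
    # assumption: RemoteProviderConfig accepts the remote stack's URL
    config = RemoteProviderConfig(url="http://localhost:5001")
    impls = await resolve_remote_stack_impls(config, apis=["inference"])
    # thanks to additional_protocols_map, the models registry rides along
    print(await impls[Api.models].list_models())


asyncio.run(main())
```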
```diff
@@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
 
     api = get_impl_api(p)
 
-    if obj.provider_id == "remote":
-        # TODO: this is broken right now because we use the generic
-        # { identifier, provider_id, provider_resource_id } tuple here
-        # but the APIs expect things like ModelInput, ShieldInput, etc.
-
-        # if this is just a passthrough, we want to let the remote
-        # end actually do the registration with the correct provider
-        obj = obj.model_copy(deep=True)
-        obj.provider_id = ""
+    assert obj.provider_id != "remote", "Remote provider should not be registered"
 
     if api == Api.inference:
         return await p.register_model(obj)
     elif api == Api.safety:
-        await p.register_shield(obj)
+        return await p.register_shield(obj)
     elif api == Api.memory:
-        await p.register_memory_bank(obj)
+        return await p.register_memory_bank(obj)
     elif api == Api.datasetio:
-        await p.register_dataset(obj)
+        return await p.register_dataset(obj)
     elif api == Api.scoring:
-        await p.register_scoring_function(obj)
+        return await p.register_scoring_function(obj)
     elif api == Api.eval:
-        await p.register_eval_task(obj)
+        return await p.register_eval_task(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
@@ -82,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
             if cls is None:
                 obj.provider_id = provider_id
             else:
-                if provider_id == "remote":
-                    # if this is just a passthrough, we got the *WithProvider object
-                    # so we should just override the provider in-place
-                    obj.provider_id = provider_id
-                else:
-                    # Create a copy of the model data and explicitly set provider_id
-                    model_data = obj.model_dump()
-                    model_data["provider_id"] = provider_id
-                    obj = cls(**model_data)
+                # Create a copy of the model data and explicitly set provider_id
+                model_data = obj.model_dump()
+                model_data["provider_id"] = provider_id
+                obj = cls(**model_data)
             await self.dist_registry.register(obj)
 
     # Register all objects from providers
@@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable):
                 p.model_store = self
             elif api == Api.safety:
                 p.shield_store = self
-
             elif api == Api.memory:
                 p.memory_bank_store = self
-
             elif api == Api.datasetio:
                 p.dataset_store = self
-
             elif api == Api.scoring:
                 p.scoring_function_store = self
                 scoring_functions = await p.list_scoring_functions()
                 await add_objects(scoring_functions, pid, ScoringFn)
-
             elif api == Api.eval:
                 p.eval_task_store = self
```
```diff
@@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
         await impl.shutdown()
 
 
-def create_dynamic_passthrough(
-    downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
-):
-    async def endpoint(request: Request):
-        return await passthrough(request, downstream_url, downstream_headers)
-
-    return endpoint
-
-
 def is_streaming_request(func_name: str, request: Request, **kwargs):
     # TODO: pass the api method and punt it to the Protocol definition directly
     return kwargs.get("stream", False)
@@ -305,28 +296,19 @@ def main(
         endpoints = all_endpoints[api]
         impl = impls[api]
 
-        if is_passthrough(impl.__provider_spec__):
-            for endpoint in endpoints:
-                url = impl.__provider_config__.url.rstrip("/") + endpoint.route
-                getattr(app, endpoint.method)(endpoint.route)(
-                    create_dynamic_passthrough(url)
-                )
-        else:
-            for endpoint in endpoints:
-                if not hasattr(impl, endpoint.name):
-                    # ideally this should be a typing violation already
-                    raise ValueError(
-                        f"Could not find method {endpoint.name} on {impl}!!"
-                    )
+        for endpoint in endpoints:
+            if not hasattr(impl, endpoint.name):
+                # ideally this should be a typing violation already
+                raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
 
-                impl_method = getattr(impl, endpoint.name)
+            impl_method = getattr(impl, endpoint.name)
 
-                getattr(app, endpoint.method)(endpoint.route, response_model=None)(
-                    create_dynamic_typed_route(
-                        impl_method,
-                        endpoint.method,
-                    )
-                )
+            getattr(app, endpoint.method)(endpoint.route, response_model=None)(
+                create_dynamic_typed_route(
+                    impl_method,
+                    endpoint.method,
+                )
+            )
 
         cprint(f"Serving API {api_str}", "white", attrs=["bold"])
         for endpoint in endpoints:
```
```diff
@@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import * # noqa: F403
 
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
-from llama_stack.distribution.resolver import resolve_impls
+from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api
@@ -58,29 +58,23 @@ class LlamaStack(
     pass
 
 
-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
-    dist_registry, _ = await create_dist_registry(
-        run_config.metadata_store, run_config.image_name
-    )
-
-    impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
-
-    resources = [
-        ("models", Api.models, "register_model", "list_models"),
-        ("shields", Api.shields, "register_shield", "list_shields"),
-        ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
-        ("datasets", Api.datasets, "register_dataset", "list_datasets"),
-        (
-            "scoring_fns",
-            Api.scoring_functions,
-            "register_scoring_function",
-            "list_scoring_functions",
-        ),
-        ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
-    ]
-    for rsrc, api, register_method, list_method in resources:
+RESOURCES = [
+    ("models", Api.models, "register_model", "list_models"),
+    ("shields", Api.shields, "register_shield", "list_shields"),
+    ("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
+    ("datasets", Api.datasets, "register_dataset", "list_datasets"),
+    (
+        "scoring_fns",
+        Api.scoring_functions,
+        "register_scoring_function",
+        "list_scoring_functions",
+    ),
+    ("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
+]
+
+
+async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
+    for rsrc, api, register_method, list_method in RESOURCES:
         objects = getattr(run_config, rsrc)
         if api not in impls:
             continue
@@ -96,4 +90,18 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
         )
 
     print("")
+
+
+# Produces a stack of providers for the given run config. Not all APIs may be
+# asked for in the run config.
+async def construct_stack(
+    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
+) -> Dict[Api, Any]:
+    dist_registry, _ = await create_dist_registry(
+        run_config.metadata_store, run_config.image_name
+    )
+    impls = await resolve_impls(
+        run_config, provider_registry or get_provider_registry(), dist_registry
+    )
+    await register_resources(run_config, impls)
+    return impls
```
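Splitting `register_resources` out of `construct_stack` and letting callers inject a `ProviderRegistry` is what makes the test resolver later in this diff possible. A hedged sketch of the injection point — `make_test_spec` is hypothetical; only `construct_stack`'s signature comes from the diff:

```python
# Hypothetical: augmenting the registry before constructing a stack,
# e.g. to add a fake/test provider type.
registry = get_provider_registry()
registry[Api.inference]["test::fake"] = make_test_spec(Api.inference)  # invented helper
impls = await construct_stack(run_config, provider_registry=registry)
```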
```diff
@@ -99,6 +99,7 @@ class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
+# TODO: this can now be inlined into RemoteProviderSpec
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_type: str = Field(
@@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
 
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: Optional[AdapterSpec] = Field(
-        default=None,
+    adapter: AdapterSpec = Field(
         description="""
 If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here. If not specified, it indicates the remote
-as being "Llama Stack compatible"
+API responses, specify the adapter here.
 """,
     )
@@ -186,38 +185,21 @@ as being "Llama Stack compatible"
 
     @property
     def module(self) -> str:
-        if self.adapter:
-            return self.adapter.module
-        return "llama_stack.distribution.client"
+        return self.adapter.module
 
     @property
     def pip_packages(self) -> List[str]:
-        if self.adapter:
-            return self.adapter.pip_packages
-        return []
+        return self.adapter.pip_packages
 
     @property
     def provider_data_validator(self) -> Optional[str]:
-        if self.adapter:
-            return self.adapter.provider_data_validator
-        return None
+        return self.adapter.provider_data_validator
 
 
-def is_passthrough(spec: ProviderSpec) -> bool:
-    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
-
-
-# Can avoid this by using Pydantic computed_field
-def remote_provider_spec(
-    api: Api, adapter: Optional[AdapterSpec] = None
-) -> RemoteProviderSpec:
-    config_class = (
-        adapter.config_class
-        if adapter and adapter.config_class
-        else "llama_stack.distribution.datatypes.RemoteProviderConfig"
-    )
-    provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
-
+def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
     return RemoteProviderSpec(
-        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
+        api=api,
+        provider_type=f"remote::{adapter.adapter_type}",
+        config_class=adapter.config_class,
+        adapter=adapter,
    )
```
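`remote_provider_spec` now requires an `AdapterSpec`; the bare `provider_type="remote"` spelling is gone. The chromadb entry in the memory registry hunk below shows the resulting call shape — remotes that speak the Llama Stack API natively simply point `config_class` at `RemoteProviderConfig`:

```python
# Taken from the memory registry change in this PR.
remote_provider_spec(
    Api.memory,
    AdapterSpec(
        adapter_type="chromadb",
        pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
        module="llama_stack.providers.remote.memory.chroma",
        config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
    ),
)
```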
```diff
@@ -234,7 +234,7 @@ class LlamaGuardShield:
         # TODO: llama-stack inference protocol has issues with non-streaming inference code
         content = ""
         async for chunk in await self.inference_api.chat_completion(
-            model=self.model,
+            model_id=self.model,
             messages=[shield_input_message],
             stream=True,
         ):
```
```diff
@@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]:
             adapter_type="chromadb",
             pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
             module="llama_stack.providers.remote.memory.chroma",
+            config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
         ),
     ),
     remote_provider_spec(
```
```diff
@@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
         model = await self.model_store.get_model(model_id)
-        print(f"model={model}")
         request = ChatCompletionRequest(
             model=model.provider_resource_id,
             messages=messages,
```
```diff
@@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
 
 from ..inference.fixtures import INFERENCE_FIXTURES
 from ..memory.fixtures import MEMORY_FIXTURES
-from ..safety.fixtures import SAFETY_FIXTURES
+from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
 from .fixtures import AGENTS_FIXTURES
@@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="together",
         marks=pytest.mark.together,
     ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "safety": "llama_guard",
+            "memory": "faiss",
+            "agents": "meta_reference",
+        },
+        id="fireworks",
+        marks=pytest.mark.fireworks,
+    ),
     pytest.param(
         {
             "inference": "remote",
@@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
 
 
 def pytest_configure(config):
-    for mark in ["meta_reference", "ollama", "together", "remote"]:
+    for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
         config.addinivalue_line(
             "markers",
             f"{mark}: marks tests as {mark} specific",
@@ -75,28 +85,30 @@ def pytest_addoption(parser):
         help="Specify the inference model to use for testing",
     )
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
         action="store",
         default="Llama-Guard-3-8B",
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )
 
 
 def pytest_generate_tests(metafunc):
-    safety_model = metafunc.config.getoption("--safety-model")
-    if "safety_model" in metafunc.fixturenames:
+    shield_id = metafunc.config.getoption("--safety-shield")
+    if "safety_shield" in metafunc.fixturenames:
         metafunc.parametrize(
-            "safety_model",
-            [pytest.param(safety_model, id="")],
+            "safety_shield",
+            [pytest.param(shield_id, id="")],
             indirect=True,
         )
     if "inference_model" in metafunc.fixturenames:
         inference_model = metafunc.config.getoption("--inference-model")
-        models = list(set({inference_model, safety_model}))
+        models = set({inference_model})
+        if safety_model := safety_model_from_shield(shield_id):
+            models.add(safety_model)
+
         metafunc.parametrize(
             "inference_model",
-            [pytest.param(models, id="")],
+            [pytest.param(list(models), id="")],
             indirect=True,
         )
     if "agents_stack" in metafunc.fixturenames:
```
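The conftest now derives the set of models to serve from the shield id via `safety_model_from_shield` (defined in the safety fixtures further down): most shields are backed by an identically named model, while Bedrock/CodeScanner/CodeShield need none. Sketch (values illustrative):

```python
# Mirrors pytest_generate_tests above.
models = {"Llama3.1-8B-Instruct"}
if safety_model := safety_model_from_shield("Llama-Guard-3-1B"):
    models.add(safety_model)  # now {"Llama3.1-8B-Instruct", "Llama-Guard-3-1B"}
if safety_model := safety_model_from_shield("Bedrock"):
    models.add(safety_model)  # no-op: Bedrock shields need no local model
```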
```diff
@@ -16,10 +16,9 @@ from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )
 
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 from ..conftest import ProviderFixture, remote_stack_fixture
-from ..safety.fixtures import get_shield_to_register
 
 
 def pick_inference_model(inference_model):
@@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def agents_stack(request, inference_model, safety_model):
+async def agents_stack(request, inference_model, safety_shield):
     fixture_dict = request.param
 
     providers = {}
@@ -71,13 +70,10 @@ async def agents_stack(request, inference_model, safety_shield):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
-    shield_input = get_shield_to_register(
-        providers["safety"][0].provider_type, safety_model
-    )
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.agents, Api.inference, Api.safety, Api.memory],
         providers,
         provider_data,
@@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_shield):
             )
             for model in inference_models
         ],
-        shields=[shield_input],
+        shields=[safety_shield],
     )
-    return impls[Api.agents], impls[Api.memory]
+    return test_stack
```
```diff
@@ -1,148 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
-from llama_stack.providers.datatypes import *  # noqa: F403
-
-from dotenv import load_dotenv
-
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
-
-# How to run this test:
-#
-# 1. Ensure you have a conda environment with the right dependencies installed.
-#    This includes `pytest` and `pytest-asyncio`.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#  PROVIDER_CONFIG=provider_config.yaml \
-#  pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
-#  --tb=short --disable-warnings
-# ```
-
-load_dotenv()
-
-
-@pytest_asyncio.fixture(scope="session")
-async def agents_settings():
-    impls = await resolve_impls_for_test(
-        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
-    )
-
-    return {
-        "impl": impls[Api.agents],
-        "memory_impl": impls[Api.memory],
-        "common_params": {
-            "model": "Llama3.1-8B-Instruct",
-            "instructions": "You are a helpful assistant.",
-        },
-    }
-
-
-@pytest.fixture
-def sample_messages():
-    return [
-        UserMessage(content="What's the weather like today?"),
-    ]
-
-
-@pytest.mark.asyncio
-async def test_delete_agents_and_sessions(agents_settings, sample_messages):
-    agents_impl = agents_settings["impl"]
-    # First, create an agent
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-    persistence_store = await kvstore_impl(agents_settings["persistence"])
-
-    await agents_impl.delete_agents_session(agent_id, session_id)
-    session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
-
-    await agents_impl.delete_agents(agent_id)
-    agent_response = await persistence_store.get(f"agent:{agent_id}")
-
-    assert session_response is None
-    assert agent_response is None
-
-
-async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
-    agents_impl = agents_settings["impl"]
-
-    # First, create an agent
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    # Create and execute a turn
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=sample_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-    ]
-
-    final_event = turn_response[-1].event.payload
-    turn_id = final_event.turn.turn_id
-    persistence_store = await kvstore_impl(SqliteKVStoreConfig())
-    turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
-    response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
-
-    assert isinstance(response, Turn)
-    assert response == final_event.turn
-    assert turn == final_event.turn
-
-    steps = final_event.turn.steps
-    step_id = steps[0].step_id
-    step_response = await agents_impl.get_agents_step(
-        agent_id, session_id, turn_id, step_id
-    )
-
-    assert isinstance(step_response.step, Step)
-    assert step_response.step == steps[0]
```
```diff
@@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403
 # -m "meta_reference"
 
 from .fixtures import pick_inference_model
+from .utils import create_agent_session
 
 
 @pytest.fixture
@@ -67,31 +68,19 @@ def query_attachment_messages():
     ]
 
 
-async def create_agent_session(agents_impl, agent_config):
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-    return agent_id, session_id
-
-
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
-        self, safety_model, agents_stack, common_params
+        self, safety_shield, agents_stack, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
        agent_id, session_id = await create_agent_session(
            agents_impl,
            AgentConfig(
                **{
                    **common_params,
-                    "input_shields": [safety_model],
-                    "output_shields": [safety_model],
+                    "input_shields": [safety_shield.shield_id],
+                    "output_shields": [safety_shield.shield_id],
                }
            ),
        )
@@ -127,7 +116,7 @@ class TestAgents:
     async def test_create_agent_turn(
         self, agents_stack, sample_messages, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
 
         agent_id, session_id = await create_agent_session(
             agents_impl, AgentConfig(**common_params)
@@ -158,7 +147,7 @@ class TestAgents:
         query_attachment_messages,
         common_params,
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
         urls = [
             "memory_optimizations.rst",
             "chat.rst",
@@ -226,7 +215,7 @@ class TestAgents:
     async def test_create_agent_turn_with_brave_search(
         self, agents_stack, search_query_messages, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
 
         if "BRAVE_SEARCH_API_KEY" not in os.environ:
             pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
```
llama_stack/providers/tests/agents/test_persistence.py (new file, +122)
```diff
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.providers.datatypes import *  # noqa: F403
+
+from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+from .fixtures import pick_inference_model
+
+from .utils import create_agent_session
+
+
+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What's the weather like today?"),
+    ]
+
+
+@pytest.fixture
+def common_params(inference_model):
+    inference_model = pick_inference_model(inference_model)
+
+    return dict(
+        model=inference_model,
+        instructions="You are a helpful assistant.",
+        enable_session_persistence=True,
+        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        input_shields=[],
+        output_shields=[],
+        tools=[],
+        max_infer_iters=5,
+    )
+
+
+class TestAgentPersistence:
+    @pytest.mark.asyncio
+    async def test_delete_agents_and_sessions(self, agents_stack, common_params):
+        agents_impl = agents_stack.impls[Api.agents]
+        agent_id, session_id = await create_agent_session(
+            agents_impl,
+            AgentConfig(
+                **{
+                    **common_params,
+                    "input_shields": [],
+                    "output_shields": [],
+                }
+            ),
+        )
+
+        run_config = agents_stack.run_config
+        provider_config = run_config.providers["agents"][0].config
+        persistence_store = await kvstore_impl(
+            SqliteKVStoreConfig(**provider_config["persistence_store"])
+        )
+
+        await agents_impl.delete_agents_session(agent_id, session_id)
+        session_response = await persistence_store.get(
+            f"session:{agent_id}:{session_id}"
+        )
+
+        await agents_impl.delete_agents(agent_id)
+        agent_response = await persistence_store.get(f"agent:{agent_id}")
+
+        assert session_response is None
+        assert agent_response is None
+
+    @pytest.mark.asyncio
+    async def test_get_agent_turns_and_steps(
+        self, agents_stack, sample_messages, common_params
+    ):
+        agents_impl = agents_stack.impls[Api.agents]
+
+        agent_id, session_id = await create_agent_session(
+            agents_impl,
+            AgentConfig(
+                **{
+                    **common_params,
+                    "input_shields": [],
+                    "output_shields": [],
+                }
+            ),
+        )
+
+        # Create and execute a turn
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=sample_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+        ]
+
+        final_event = turn_response[-1].event.payload
+        turn_id = final_event.turn.turn_id
+
+        provider_config = agents_stack.run_config.providers["agents"][0].config
+        persistence_store = await kvstore_impl(
+            SqliteKVStoreConfig(**provider_config["persistence_store"])
+        )
+        turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
+        response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
+
+        assert isinstance(response, Turn)
+        assert response == final_event.turn
+        assert turn == final_event.turn.model_dump_json()
+
+        steps = final_event.turn.steps
+        step_id = steps[0].step_id
+        step_response = await agents_impl.get_agents_step(
+            agent_id, session_id, turn_id, step_id
+        )
+
+        assert step_response.step == steps[0]
```
llama_stack/providers/tests/agents/utils.py (new file, +17)
```diff
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+async def create_agent_session(agents_impl, agent_config):
+    create_response = await agents_impl.create_agent(agent_config)
+    agent_id = create_response.agent_id
+
+    # Create a session
+    session_create_response = await agents_impl.create_agent_session(
+        agent_id, "Test Session"
+    )
+    session_id = session_create_response.session_id
+    return agent_id, session_id
```
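This helper keeps the agent/session boilerplate in one place; both test modules above now call it the same way:

```python
# As used in test_agents.py and test_persistence.py.
agent_id, session_id = await create_agent_session(
    agents_impl, AgentConfig(**common_params)
)
```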
```diff
@@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="remote",
-                provider_type="remote",
+                provider_id="test::remote",
+                provider_type="test::remote",
                 config=config.model_dump(),
             )
         ],
```
```diff
@@ -9,7 +9,7 @@ import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
 
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -52,10 +52,10 @@ async def datasetio_stack(request):
     fixture_name = request.param
     fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.datasetio],
         {"datasetio": fixture.providers},
         fixture.provider_data,
     )
 
-    return impls[Api.datasetio], impls[Api.datasets]
+    return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets]
```
```diff
@@ -9,7 +9,7 @@ import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
 
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -46,10 +46,10 @@ async def eval_stack(request):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.eval, Api.datasetio, Api.inference, Api.scoring],
         providers,
         provider_data,
     )
 
-    return impls
+    return test_stack.impls
```
```diff
@@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -182,15 +182,11 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[
-            ModelInput(
-                model_id=inference_model,
-            )
-        ],
+        models=[ModelInput(model_id=inference_model)],
     )
 
-    return (impls[Api.inference], impls[Api.models])
+    return test_stack.impls[Api.inference], test_stack.impls[Api.models]
```
```diff
@@ -147,9 +147,9 @@ class TestInference:
 
         user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
         response = await inference_impl.completion(
+            model_id=inference_model,
             content=user_input,
             stream=False,
-            model=inference_model,
             sampling_params=SamplingParams(
                 max_tokens=50,
             ),
```
```diff
@@ -55,7 +55,7 @@ class TestVisionModelInference:
         )
 
         response = await inference_impl.chat_completion(
-            model=inference_model,
+            model_id=inference_model,
             messages=[
                 UserMessage(content="You are a helpful assistant."),
                 UserMessage(content=[image, "Describe this image in two sentences."]),
@@ -102,7 +102,7 @@ class TestVisionModelInference:
         response = [
             r
             async for r in await inference_impl.chat_completion(
-                model=inference_model,
+                model_id=inference_model,
                 messages=[
                     UserMessage(content="You are a helpful assistant."),
                     UserMessage(
```
```diff
@@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConf
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -101,10 +101,10 @@ async def memory_stack(request):
     fixture_name = request.param
     fixture = request.getfixturevalue(f"memory_{fixture_name}")
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.memory],
         {"memory": fixture.providers},
         fixture.provider_data,
     )
 
-    return impls[Api.memory], impls[Api.memory_banks]
+    return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
```
```diff
@@ -5,33 +5,36 @@
 # the root directory of this source tree.
 
 import json
-import os
 import tempfile
 from datetime import datetime
 from typing import Any, Dict, List, Optional
 
-import yaml
-
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import resolve_impls
+from llama_stack.distribution.resolver import resolve_remote_stack_impls
 from llama_stack.distribution.stack import construct_stack
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
 
 
-async def resolve_impls_for_test_v2(
+class TestStack(BaseModel):
+    impls: Dict[Api, Any]
+    run_config: StackRunConfig
+
+
+async def construct_stack_for_test(
     apis: List[Api],
     providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
-    models: Optional[List[Model]] = None,
-    shields: Optional[List[Shield]] = None,
-    memory_banks: Optional[List[MemoryBank]] = None,
-    datasets: Optional[List[Dataset]] = None,
-    scoring_fns: Optional[List[ScoringFn]] = None,
-    eval_tasks: Optional[List[EvalTask]] = None,
-):
+    models: Optional[List[ModelInput]] = None,
+    shields: Optional[List[ShieldInput]] = None,
+    memory_banks: Optional[List[MemoryBankInput]] = None,
+    datasets: Optional[List[DatasetInput]] = None,
+    scoring_fns: Optional[List[ScoringFnInput]] = None,
+    eval_tasks: Optional[List[EvalTaskInput]] = None,
+) -> TestStack:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(
         built_at=datetime.now(),
@@ -48,7 +51,18 @@ async def construct_stack_for_test(
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
     try:
-        impls = await construct_stack(run_config)
+        remote_config = remote_provider_config(run_config)
+        if not remote_config:
+            # TODO: add to provider registry by creating interesting mocks or fakes
+            impls = await construct_stack(run_config, get_provider_registry())
+        else:
+            # we don't register resources for a remote stack as part of the fixture setup
+            # because the stack is already "up". if a test needs to register resources, it
+            # can do so manually always.
+
+            impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
+
+        test_stack = TestStack(impls=impls, run_config=run_config)
     except ModuleNotFoundError as e:
         print_pip_install_help(providers)
         raise e
@@ -58,91 +72,22 @@ async def construct_stack_for_test(
         {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
     )
 
-    return impls
+    return test_stack
 
 
-async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
-    if "PROVIDER_CONFIG" not in os.environ:
-        raise ValueError(
-            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
-        )
-
-    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
-        config_dict = yaml.safe_load(f)
-
-    providers = read_providers(api, config_dict)
-
-    chosen = choose_providers(providers, api, deps)
-    run_config = dict(
-        built_at=datetime.now(),
-        image_name="test-fixture",
-        apis=[api] + (deps or []),
-        providers=chosen,
-    )
-    run_config = parse_and_maybe_upgrade_config(run_config)
-    try:
-        impls = await resolve_impls(run_config, get_provider_registry())
-    except ModuleNotFoundError as e:
-        print_pip_install_help(providers)
-        raise e
-
-    if "provider_data" in config_dict:
-        provider_id = chosen[api.value][0].provider_id
-        provider_data = config_dict["provider_data"].get(provider_id, {})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
-            )
-
-    return impls
-
-
-def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
-    if "providers" not in config_dict:
-        raise ValueError("Config file should contain a `providers` key")
-
-    providers = config_dict["providers"]
-    if isinstance(providers, dict):
-        return providers
-    elif isinstance(providers, list):
-        return {
-            api.value: providers,
-        }
-    else:
-        raise ValueError(
-            "Config file should contain a list of providers or dict(api to providers)"
-        )
-
-
-def choose_providers(
-    providers: Dict[str, Any], api: Api, deps: List[Api] = None
-) -> Dict[str, Provider]:
-    chosen = {}
-    if api.value not in providers:
-        raise ValueError(f"No providers found for `{api}`?")
-    chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")]
-
-    for dep in deps or []:
-        if dep.value not in providers:
-            raise ValueError(f"No providers specified for `{dep}` in config?")
-        chosen[dep.value] = [Provider(**x) for x in providers[dep.value]]
-
-    return chosen
-
-
-def pick_provider(api: Api, providers: List[Any], key: str) -> Provider:
-    providers_by_id = {x["provider_id"]: x for x in providers}
-    if len(providers_by_id) == 0:
-        raise ValueError(f"No providers found for `{api}` in config file")
-
-    if key in os.environ:
-        provider_id = os.environ[key]
-        if provider_id not in providers_by_id:
-            raise ValueError(f"Provider ID {provider_id} not found in config file")
-        provider = providers_by_id[provider_id]
-    else:
-        provider = list(providers_by_id.values())[0]
-        provider_id = provider["provider_id"]
-        print(f"No provider ID specified, picking first `{provider_id}`")
-
-    return Provider(**provider)
+def remote_provider_config(
+    run_config: StackRunConfig,
+) -> Optional[RemoteProviderConfig]:
+    remote_config = None
+    has_non_remote = False
+    for api_providers in run_config.providers.values():
+        for provider in api_providers:
+            if provider.provider_type == "test::remote":
+                remote_config = RemoteProviderConfig(**provider.config)
+            else:
+                has_non_remote = True
+
+    if remote_config:
+        assert not has_non_remote, "Remote stack cannot have non-remote providers"
+
+    return remote_config
```
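`construct_stack_for_test` returning a `TestStack` (the impls plus the generated run config) is what lets `test_persistence.py` locate the agents provider's SQLite persistence store without module-level globals. A hedged sketch of a fixture built on it (the model id is illustrative):

```python
# Modeled on the fixtures in this PR.
test_stack = await construct_stack_for_test(
    [Api.inference],
    {"inference": fixture.providers},
    fixture.provider_data,
    models=[ModelInput(model_id="Llama3.1-8B-Instruct")],
)
inference_impl = test_stack.impls[Api.inference]
providers_cfg = test_stack.run_config.providers  # e.g. to find a persistence store
```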
```diff
@@ -66,14 +66,14 @@ def pytest_configure(config):
 
 def pytest_addoption(parser):
     parser.addoption(
-        "--safety-model",
+        "--safety-shield",
         action="store",
         default=None,
-        help="Specify the safety model to use for testing",
+        help="Specify the safety shield to use for testing",
     )
 
 
-SAFETY_MODEL_PARAMS = [
+SAFETY_SHIELD_PARAMS = [
     pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
 ]
@@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
     # But a user can also pass in a custom combination via the CLI by doing
     # `--providers inference=together,safety=meta_reference`
 
-    if "safety_model" in metafunc.fixturenames:
-        model = metafunc.config.getoption("--safety-model")
-        if model:
-            params = [pytest.param(model, id="")]
+    if "safety_shield" in metafunc.fixturenames:
+        shield_id = metafunc.config.getoption("--safety-shield")
+        if shield_id:
+            params = [pytest.param(shield_id, id="")]
         else:
-            params = SAFETY_MODEL_PARAMS
-        for fixture in ["inference_model", "safety_model"]:
+            params = SAFETY_SHIELD_PARAMS
+        for fixture in ["inference_model", "safety_shield"]:
             metafunc.parametrize(
                 fixture,
                 params,
```
```diff
@@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
 from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
 
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail
@@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture:
     return remote_stack_fixture()
 
 
+def safety_model_from_shield(shield_id):
+    if shield_id in ("Bedrock", "CodeScanner", "CodeShield"):
+        return None
+
+    return shield_id
+
+
 @pytest.fixture(scope="session")
-def safety_model(request):
+def safety_shield(request):
     if hasattr(request, "param"):
-        return request.param
-    return request.config.getoption("--safety-model", None)
+        shield_id = request.param
+    else:
+        shield_id = request.config.getoption("--safety-shield", None)
+
+    if shield_id == "bedrock":
+        shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
+        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
+    else:
+        params = {}
+
+    return ShieldInput(
+        shield_id=shield_id,
+        params=params,
+    )
 
 
 @pytest.fixture(scope="session")
-def safety_llama_guard(safety_model) -> ProviderFixture:
+def safety_llama_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::llama-guard",
+                provider_id="llama-guard",
                 provider_type="inline::llama-guard",
                 config=LlamaGuardConfig().model_dump(),
             )
@@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="inline::prompt-guard",
+                provider_id="prompt-guard",
                 provider_type="inline::prompt-guard",
                 config=PromptGuardConfig().model_dump(),
             )
@@ -80,50 +99,25 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def safety_stack(inference_model, safety_model, request):
+async def safety_stack(inference_model, safety_shield, request):
     # We need an inference + safety fixture to test safety
     fixture_dict = request.param
-    inference_fixture = request.getfixturevalue(
-        f"inference_{fixture_dict['inference']}"
-    )
-    safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
 
-    providers = {
-        "inference": inference_fixture.providers,
-        "safety": safety_fixture.providers,
-    }
+    providers = {}
     provider_data = {}
-    if inference_fixture.provider_data:
-        provider_data.update(inference_fixture.provider_data)
-    if safety_fixture.provider_data:
-        provider_data.update(safety_fixture.provider_data)
+    for key in ["inference", "safety"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
 
-    shield_provider_type = safety_fixture.providers[0].provider_type
-    shield_input = get_shield_to_register(shield_provider_type, safety_model)
-
-    print(f"inference_model: {inference_model}")
-    print(f"shield_input = {shield_input}")
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
         models=[ModelInput(model_id=inference_model)],
-        shields=[shield_input],
+        shields=[safety_shield],
     )
 
-    shield = await impls[Api.shields].get_shield(shield_input.shield_id)
-    return impls[Api.safety], impls[Api.shields], shield
-
-
-def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
-    if provider_type == "remote::bedrock":
-        identifier = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
-        params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
-    else:
-        params = {}
-        identifier = safety_model
-
-    return ShieldInput(
-        shield_id=identifier,
-        params=params,
-    )
+    shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
+    return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield
```
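`safety_shield` now yields a full `ShieldInput` instead of a bare model id, so Bedrock's guardrail identifier and version travel with the shield. For `--safety-shield=bedrock` the fixture effectively produces (env values shown as placeholders):

```python
ShieldInput(
    shield_id="my-guardrail-id",       # from BEDROCK_GUARDRAIL_IDENTIFIER
    params={"guardrailVersion": "1"},  # from BEDROCK_GUARDRAIL_VERSION
)
```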
```diff
@@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403
 
 
 class TestSafety:
-    @pytest.mark.asyncio
-    async def test_new_shield(self, safety_stack):
-        _, shields_impl, shield = safety_stack
-        assert shield is not None
-        assert shield.provider_resource_id == shield.identifier
-        assert shield.provider_id is not None
-
     @pytest.mark.asyncio
     async def test_shield_list(self, safety_stack):
         _, shields_impl, _ = safety_stack
```
```diff
@@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput
 
 from llama_stack.distribution.datatypes import Api, Provider
 
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from ..conftest import ProviderFixture, remote_stack_fixture
@@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model):
     if fixture.provider_data:
         provider_data.update(fixture.provider_data)
 
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.scoring, Api.datasetio, Api.inference],
         providers,
         provider_data,
@@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model):
         ],
     )
 
-    return impls
+    return test_stack.impls
```