forked from phoenix-oss/llama-stack-mirror
Kill "remote" providers and fix testing with a remote stack properly (#435)
# What does this PR do?

This PR kills the notion of "pure passthrough" remote providers. You cannot specify a single provider as remote; you must specify a whole distribution (stack) as remote.

This PR also significantly fixes / upgrades the testing infrastructure, so you can now test against a remotely hosted stack server by just doing:

```bash
pytest -s -v -m remote test_agents.py \
  --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \
  --env REMOTE_STACK_URL=http://localhost:5001
```

Also fixed `test_agents_persistence.py` (which was broken) and killed some deprecated testing functions.

## Test Plan

All the tests.
This commit is contained in:
parent
59a65e34d3
commit
12947ac19e
28 changed files with 406 additions and 519 deletions
|
@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||||
_CLIENT_CLASSES = {}
|
_CLIENT_CLASSES = {}
|
||||||
|
|
||||||
|
|
||||||
async def get_client_impl(
|
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
|
||||||
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
|
client_class = create_api_client_class(protocol)
|
||||||
):
|
|
||||||
client_class = create_api_client_class(protocol, additional_protocol)
|
|
||||||
impl = client_class(config.url)
|
impl = client_class(config.url)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
||||||
def create_api_client_class(protocol, additional_protocol) -> Type:
|
def create_api_client_class(protocol) -> Type:
|
||||||
if protocol in _CLIENT_CLASSES:
|
if protocol in _CLIENT_CLASSES:
|
||||||
return _CLIENT_CLASSES[protocol]
|
return _CLIENT_CLASSES[protocol]
|
||||||
|
|
||||||
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
|
|
||||||
|
|
||||||
class APIClient:
|
class APIClient:
|
||||||
def __init__(self, base_url: str):
|
def __init__(self, base_url: str):
|
||||||
print(f"({protocol.__name__}) Connecting to {base_url}")
|
print(f"({protocol.__name__}) Connecting to {base_url}")
|
||||||
|
@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
||||||
self.routes = {}
|
self.routes = {}
|
||||||
|
|
||||||
# Store routes for this protocol
|
# Store routes for this protocol
|
||||||
for p in protocols:
|
for name, method in inspect.getmembers(protocol):
|
||||||
for name, method in inspect.getmembers(p):
|
if hasattr(method, "__webmethod__"):
|
||||||
if hasattr(method, "__webmethod__"):
|
sig = inspect.signature(method)
|
||||||
sig = inspect.signature(method)
|
self.routes[name] = (method.__webmethod__, sig)
|
||||||
self.routes[name] = (method.__webmethod__, sig)
|
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# Add protocol methods to the wrapper
|
# Add protocol methods to the wrapper
|
||||||
for p in protocols:
|
for name, method in inspect.getmembers(protocol):
|
||||||
for name, method in inspect.getmembers(p):
|
if hasattr(method, "__webmethod__"):
|
||||||
if hasattr(method, "__webmethod__"):
|
|
||||||
|
|
||||||
async def method_impl(self, *args, method_name=name, **kwargs):
|
async def method_impl(self, *args, method_name=name, **kwargs):
|
||||||
return await self.__acall__(method_name, *args, **kwargs)
|
return await self.__acall__(method_name, *args, **kwargs)
|
||||||
|
|
||||||
method_impl.__name__ = name
|
method_impl.__name__ = name
|
||||||
method_impl.__qualname__ = f"APIClient.{name}"
|
method_impl.__qualname__ = f"APIClient.{name}"
|
||||||
method_impl.__signature__ = inspect.signature(method)
|
method_impl.__signature__ = inspect.signature(method)
|
||||||
setattr(APIClient, name, method_impl)
|
setattr(APIClient, name, method_impl)
|
||||||
|
|
||||||
# Name the class after the protocol
|
# Name the class after the protocol
|
||||||
APIClient.__name__ = f"{protocol.__name__}Client"
|
APIClient.__name__ = f"{protocol.__name__}Client"
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
def stack_apis() -> List[Api]:
|
def stack_apis() -> List[Api]:
|
||||||
|
@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||||
for api in providable_apis():
|
for api in providable_apis():
|
||||||
name = api.name.lower()
|
name = api.name.lower()
|
||||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||||
ret[api] = {
|
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||||
"remote": remote_provider_spec(api),
|
|
||||||
**{a.provider_type: a for a in module.available_providers()},
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
|
||||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
|
from llama_stack.distribution.client import get_client_impl
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.distribution.store import DistributionRegistry
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
@ -59,12 +60,16 @@ def api_protocol_map() -> Dict[Api, Any]:
|
||||||
|
|
||||||
def additional_protocols_map() -> Dict[Api, Any]:
|
def additional_protocols_map() -> Dict[Api, Any]:
|
||||||
return {
|
return {
|
||||||
Api.inference: (ModelsProtocolPrivate, Models),
|
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
|
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
|
||||||
Api.safety: (ShieldsProtocolPrivate, Shields),
|
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
|
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||||
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
|
Api.scoring: (
|
||||||
Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
|
ScoringFunctionsProtocolPrivate,
|
||||||
|
ScoringFunctions,
|
||||||
|
Api.scoring_functions,
|
||||||
|
),
|
||||||
|
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,10 +78,13 @@ class ProviderWithSpec(Provider):
|
||||||
spec: ProviderSpec
|
spec: ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||||
|
|
||||||
|
|
||||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
run_config: StackRunConfig,
|
run_config: StackRunConfig,
|
||||||
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
|
provider_registry: ProviderRegistry,
|
||||||
dist_registry: DistributionRegistry,
|
dist_registry: DistributionRegistry,
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
|
@ -273,17 +281,8 @@ async def instantiate_provider(
|
||||||
config_type = instantiate_class_type(provider_spec.config_class)
|
config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
config = config_type(**provider.config)
|
config = config_type(**provider.config)
|
||||||
|
|
||||||
if provider_spec.adapter:
|
method = "get_adapter_impl"
|
||||||
method = "get_adapter_impl"
|
args = [config, deps]
|
||||||
args = [config, deps]
|
|
||||||
else:
|
|
||||||
method = "get_client_impl"
|
|
||||||
protocol = protocols[provider_spec.api]
|
|
||||||
if provider_spec.api in additional_protocols:
|
|
||||||
_, additional_protocol = additional_protocols[provider_spec.api]
|
|
||||||
else:
|
|
||||||
additional_protocol = None
|
|
||||||
args = [protocol, additional_protocol, config, deps]
|
|
||||||
|
|
||||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||||
method = "get_auto_router_impl"
|
method = "get_auto_router_impl"
|
||||||
|
@ -313,7 +312,7 @@ async def instantiate_provider(
|
||||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||||
and provider_spec.api in additional_protocols
|
and provider_spec.api in additional_protocols
|
||||||
):
|
):
|
||||||
additional_api, _ = additional_protocols[provider_spec.api]
|
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||||
check_protocol_compliance(impl, additional_api)
|
check_protocol_compliance(impl, additional_api)
|
||||||
|
|
||||||
return impl
|
return impl
|
||||||
|
@ -359,3 +358,29 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def resolve_remote_stack_impls(
|
||||||
|
config: RemoteProviderConfig,
|
||||||
|
apis: List[str],
|
||||||
|
) -> Dict[Api, Any]:
|
||||||
|
protocols = api_protocol_map()
|
||||||
|
additional_protocols = additional_protocols_map()
|
||||||
|
|
||||||
|
impls = {}
|
||||||
|
for api_str in apis:
|
||||||
|
api = Api(api_str)
|
||||||
|
impls[api] = await get_client_impl(
|
||||||
|
protocols[api],
|
||||||
|
config,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
if api in additional_protocols:
|
||||||
|
_, additional_protocol, additional_api = additional_protocols[api]
|
||||||
|
impls[additional_api] = await get_client_impl(
|
||||||
|
additional_protocol,
|
||||||
|
config,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
return impls
|
||||||
|
|
|
@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
||||||
|
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
|
|
||||||
if obj.provider_id == "remote":
|
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||||
# TODO: this is broken right now because we use the generic
|
|
||||||
# { identifier, provider_id, provider_resource_id } tuple here
|
|
||||||
# but the APIs expect things like ModelInput, ShieldInput, etc.
|
|
||||||
|
|
||||||
# if this is just a passthrough, we want to let the remote
|
|
||||||
# end actually do the registration with the correct provider
|
|
||||||
obj = obj.model_copy(deep=True)
|
|
||||||
obj.provider_id = ""
|
|
||||||
|
|
||||||
if api == Api.inference:
|
if api == Api.inference:
|
||||||
return await p.register_model(obj)
|
return await p.register_model(obj)
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
await p.register_shield(obj)
|
return await p.register_shield(obj)
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
await p.register_memory_bank(obj)
|
return await p.register_memory_bank(obj)
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
await p.register_dataset(obj)
|
return await p.register_dataset(obj)
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
await p.register_scoring_function(obj)
|
return await p.register_scoring_function(obj)
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
await p.register_eval_task(obj)
|
return await p.register_eval_task(obj)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||||
|
|
||||||
|
@ -82,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
if cls is None:
|
if cls is None:
|
||||||
obj.provider_id = provider_id
|
obj.provider_id = provider_id
|
||||||
else:
|
else:
|
||||||
if provider_id == "remote":
|
# Create a copy of the model data and explicitly set provider_id
|
||||||
# if this is just a passthrough, we got the *WithProvider object
|
model_data = obj.model_dump()
|
||||||
# so we should just override the provider in-place
|
model_data["provider_id"] = provider_id
|
||||||
obj.provider_id = provider_id
|
obj = cls(**model_data)
|
||||||
else:
|
|
||||||
# Create a copy of the model data and explicitly set provider_id
|
|
||||||
model_data = obj.model_dump()
|
|
||||||
model_data["provider_id"] = provider_id
|
|
||||||
obj = cls(**model_data)
|
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
# Register all objects from providers
|
# Register all objects from providers
|
||||||
|
@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
p.model_store = self
|
p.model_store = self
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
p.shield_store = self
|
p.shield_store = self
|
||||||
|
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
p.memory_bank_store = self
|
p.memory_bank_store = self
|
||||||
|
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
p.dataset_store = self
|
p.dataset_store = self
|
||||||
|
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
p.scoring_function_store = self
|
p.scoring_function_store = self
|
||||||
scoring_functions = await p.list_scoring_functions()
|
scoring_functions = await p.list_scoring_functions()
|
||||||
await add_objects(scoring_functions, pid, ScoringFn)
|
await add_objects(scoring_functions, pid, ScoringFn)
|
||||||
|
|
||||||
elif api == Api.eval:
|
elif api == Api.eval:
|
||||||
p.eval_task_store = self
|
p.eval_task_store = self
|
||||||
|
|
||||||
|
|
|
@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
|
||||||
await impl.shutdown()
|
await impl.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def create_dynamic_passthrough(
|
|
||||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
|
||||||
):
|
|
||||||
async def endpoint(request: Request):
|
|
||||||
return await passthrough(request, downstream_url, downstream_headers)
|
|
||||||
|
|
||||||
return endpoint
|
|
||||||
|
|
||||||
|
|
||||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||||
return kwargs.get("stream", False)
|
return kwargs.get("stream", False)
|
||||||
|
@ -305,28 +296,19 @@ def main(
|
||||||
endpoints = all_endpoints[api]
|
endpoints = all_endpoints[api]
|
||||||
impl = impls[api]
|
impl = impls[api]
|
||||||
|
|
||||||
if is_passthrough(impl.__provider_spec__):
|
for endpoint in endpoints:
|
||||||
for endpoint in endpoints:
|
if not hasattr(impl, endpoint.name):
|
||||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
# ideally this should be a typing violation already
|
||||||
getattr(app, endpoint.method)(endpoint.route)(
|
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||||
create_dynamic_passthrough(url)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for endpoint in endpoints:
|
|
||||||
if not hasattr(impl, endpoint.name):
|
|
||||||
# ideally this should be a typing violation already
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not find method {endpoint.name} on {impl}!!"
|
|
||||||
)
|
|
||||||
|
|
||||||
impl_method = getattr(impl, endpoint.name)
|
impl_method = getattr(impl, endpoint.name)
|
||||||
|
|
||||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||||
create_dynamic_typed_route(
|
create_dynamic_typed_route(
|
||||||
impl_method,
|
impl_method,
|
||||||
endpoint.method,
|
endpoint.method,
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
|
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
|
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
from llama_stack.distribution.datatypes import StackRunConfig
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||||
from llama_stack.distribution.store.registry import create_dist_registry
|
from llama_stack.distribution.store.registry import create_dist_registry
|
||||||
from llama_stack.providers.datatypes import Api
|
from llama_stack.providers.datatypes import Api
|
||||||
|
|
||||||
|
@ -58,29 +58,23 @@ class LlamaStack(
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
RESOURCES = [
|
||||||
# asked for in the run config.
|
("models", Api.models, "register_model", "list_models"),
|
||||||
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
("shields", Api.shields, "register_shield", "list_shields"),
|
||||||
dist_registry, _ = await create_dist_registry(
|
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
||||||
run_config.metadata_store, run_config.image_name
|
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||||
)
|
(
|
||||||
|
"scoring_fns",
|
||||||
|
Api.scoring_functions,
|
||||||
|
"register_scoring_function",
|
||||||
|
"list_scoring_functions",
|
||||||
|
),
|
||||||
|
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||||
|
]
|
||||||
|
|
||||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
|
||||||
|
|
||||||
resources = [
|
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||||
("models", Api.models, "register_model", "list_models"),
|
for rsrc, api, register_method, list_method in RESOURCES:
|
||||||
("shields", Api.shields, "register_shield", "list_shields"),
|
|
||||||
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
|
||||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
|
||||||
(
|
|
||||||
"scoring_fns",
|
|
||||||
Api.scoring_functions,
|
|
||||||
"register_scoring_function",
|
|
||||||
"list_scoring_functions",
|
|
||||||
),
|
|
||||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
|
||||||
]
|
|
||||||
for rsrc, api, register_method, list_method in resources:
|
|
||||||
objects = getattr(run_config, rsrc)
|
objects = getattr(run_config, rsrc)
|
||||||
if api not in impls:
|
if api not in impls:
|
||||||
continue
|
continue
|
||||||
|
@ -96,4 +90,18 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
)
|
)
|
||||||
|
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
|
||||||
|
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||||
|
# asked for in the run config.
|
||||||
|
async def construct_stack(
|
||||||
|
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||||
|
) -> Dict[Api, Any]:
|
||||||
|
dist_registry, _ = await create_dist_registry(
|
||||||
|
run_config.metadata_store, run_config.image_name
|
||||||
|
)
|
||||||
|
impls = await resolve_impls(
|
||||||
|
run_config, provider_registry or get_provider_registry(), dist_registry
|
||||||
|
)
|
||||||
|
await register_resources(run_config, impls)
|
||||||
return impls
|
return impls
|
||||||
|
|
|
@ -99,6 +99,7 @@ class RoutingTable(Protocol):
|
||||||
def get_provider_impl(self, routing_key: str) -> Any: ...
|
def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: this can now be inlined into RemoteProviderSpec
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AdapterSpec(BaseModel):
|
class AdapterSpec(BaseModel):
|
||||||
adapter_type: str = Field(
|
adapter_type: str = Field(
|
||||||
|
@ -171,12 +172,10 @@ class RemoteProviderConfig(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RemoteProviderSpec(ProviderSpec):
|
class RemoteProviderSpec(ProviderSpec):
|
||||||
adapter: Optional[AdapterSpec] = Field(
|
adapter: AdapterSpec = Field(
|
||||||
default=None,
|
|
||||||
description="""
|
description="""
|
||||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||||
API responses, specify the adapter here. If not specified, it indicates the remote
|
API responses, specify the adapter here.
|
||||||
as being "Llama Stack compatible"
|
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -186,38 +185,21 @@ as being "Llama Stack compatible"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def module(self) -> str:
|
def module(self) -> str:
|
||||||
if self.adapter:
|
return self.adapter.module
|
||||||
return self.adapter.module
|
|
||||||
return "llama_stack.distribution.client"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pip_packages(self) -> List[str]:
|
def pip_packages(self) -> List[str]:
|
||||||
if self.adapter:
|
return self.adapter.pip_packages
|
||||||
return self.adapter.pip_packages
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_data_validator(self) -> Optional[str]:
|
def provider_data_validator(self) -> Optional[str]:
|
||||||
if self.adapter:
|
return self.adapter.provider_data_validator
|
||||||
return self.adapter.provider_data_validator
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def is_passthrough(spec: ProviderSpec) -> bool:
|
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
||||||
return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
|
|
||||||
|
|
||||||
|
|
||||||
# Can avoid this by using Pydantic computed_field
|
|
||||||
def remote_provider_spec(
|
|
||||||
api: Api, adapter: Optional[AdapterSpec] = None
|
|
||||||
) -> RemoteProviderSpec:
|
|
||||||
config_class = (
|
|
||||||
adapter.config_class
|
|
||||||
if adapter and adapter.config_class
|
|
||||||
else "llama_stack.distribution.datatypes.RemoteProviderConfig"
|
|
||||||
)
|
|
||||||
provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote"
|
|
||||||
|
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
|
api=api,
|
||||||
|
provider_type=f"remote::{adapter.adapter_type}",
|
||||||
|
config_class=adapter.config_class,
|
||||||
|
adapter=adapter,
|
||||||
)
|
)
|
||||||
|
|
|
@ -234,7 +234,7 @@ class LlamaGuardShield:
|
||||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||||
content = ""
|
content = ""
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
model=self.model,
|
model_id=self.model,
|
||||||
messages=[shield_input_message],
|
messages=[shield_input_message],
|
||||||
stream=True,
|
stream=True,
|
||||||
):
|
):
|
||||||
|
|
|
@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
adapter_type="chromadb",
|
adapter_type="chromadb",
|
||||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||||
module="llama_stack.providers.remote.memory.chroma",
|
module="llama_stack.providers.remote.memory.chroma",
|
||||||
|
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
|
|
|
@ -164,7 +164,6 @@ class OllamaInferenceAdapter(Inference, ModelRegistryHelper, ModelsProtocolPriva
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
print(f"model={model}")
|
|
||||||
request = ChatCompletionRequest(
|
request = ChatCompletionRequest(
|
||||||
model=model.provider_resource_id,
|
model=model.provider_resource_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ..conftest import get_provider_fixture_overrides
|
||||||
|
|
||||||
from ..inference.fixtures import INFERENCE_FIXTURES
|
from ..inference.fixtures import INFERENCE_FIXTURES
|
||||||
from ..memory.fixtures import MEMORY_FIXTURES
|
from ..memory.fixtures import MEMORY_FIXTURES
|
||||||
from ..safety.fixtures import SAFETY_FIXTURES
|
from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield
|
||||||
from .fixtures import AGENTS_FIXTURES
|
from .fixtures import AGENTS_FIXTURES
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,6 +46,16 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
id="together",
|
id="together",
|
||||||
marks=pytest.mark.together,
|
marks=pytest.mark.together,
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
"inference": "fireworks",
|
||||||
|
"safety": "llama_guard",
|
||||||
|
"memory": "faiss",
|
||||||
|
"agents": "meta_reference",
|
||||||
|
},
|
||||||
|
id="fireworks",
|
||||||
|
marks=pytest.mark.fireworks,
|
||||||
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
{
|
{
|
||||||
"inference": "remote",
|
"inference": "remote",
|
||||||
|
@ -60,7 +70,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
for mark in ["meta_reference", "ollama", "together", "remote"]:
|
for mark in ["meta_reference", "ollama", "together", "fireworks", "remote"]:
|
||||||
config.addinivalue_line(
|
config.addinivalue_line(
|
||||||
"markers",
|
"markers",
|
||||||
f"{mark}: marks tests as {mark} specific",
|
f"{mark}: marks tests as {mark} specific",
|
||||||
|
@ -75,28 +85,30 @@ def pytest_addoption(parser):
|
||||||
help="Specify the inference model to use for testing",
|
help="Specify the inference model to use for testing",
|
||||||
)
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--safety-model",
|
"--safety-shield",
|
||||||
action="store",
|
action="store",
|
||||||
default="Llama-Guard-3-8B",
|
default="Llama-Guard-3-8B",
|
||||||
help="Specify the safety model to use for testing",
|
help="Specify the safety shield to use for testing",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
safety_model = metafunc.config.getoption("--safety-model")
|
shield_id = metafunc.config.getoption("--safety-shield")
|
||||||
if "safety_model" in metafunc.fixturenames:
|
if "safety_shield" in metafunc.fixturenames:
|
||||||
metafunc.parametrize(
|
metafunc.parametrize(
|
||||||
"safety_model",
|
"safety_shield",
|
||||||
[pytest.param(safety_model, id="")],
|
[pytest.param(shield_id, id="")],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
if "inference_model" in metafunc.fixturenames:
|
if "inference_model" in metafunc.fixturenames:
|
||||||
inference_model = metafunc.config.getoption("--inference-model")
|
inference_model = metafunc.config.getoption("--inference-model")
|
||||||
models = list(set({inference_model, safety_model}))
|
models = set({inference_model})
|
||||||
|
if safety_model := safety_model_from_shield(shield_id):
|
||||||
|
models.add(safety_model)
|
||||||
|
|
||||||
metafunc.parametrize(
|
metafunc.parametrize(
|
||||||
"inference_model",
|
"inference_model",
|
||||||
[pytest.param(models, id="")],
|
[pytest.param(list(models), id="")],
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
if "agents_stack" in metafunc.fixturenames:
|
if "agents_stack" in metafunc.fixturenames:
|
||||||
|
|
|
@ -16,10 +16,9 @@ from llama_stack.providers.inline.agents.meta_reference import (
|
||||||
MetaReferenceAgentsImplConfig,
|
MetaReferenceAgentsImplConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..safety.fixtures import get_shield_to_register
|
|
||||||
|
|
||||||
|
|
||||||
def pick_inference_model(inference_model):
|
def pick_inference_model(inference_model):
|
||||||
|
@ -60,7 +59,7 @@ AGENTS_FIXTURES = ["meta_reference", "remote"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def agents_stack(request, inference_model, safety_model):
|
async def agents_stack(request, inference_model, safety_shield):
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
|
|
||||||
providers = {}
|
providers = {}
|
||||||
|
@ -71,13 +70,10 @@ async def agents_stack(request, inference_model, safety_model):
|
||||||
if fixture.provider_data:
|
if fixture.provider_data:
|
||||||
provider_data.update(fixture.provider_data)
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
shield_input = get_shield_to_register(
|
|
||||||
providers["safety"][0].provider_type, safety_model
|
|
||||||
)
|
|
||||||
inference_models = (
|
inference_models = (
|
||||||
inference_model if isinstance(inference_model, list) else [inference_model]
|
inference_model if isinstance(inference_model, list) else [inference_model]
|
||||||
)
|
)
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.agents, Api.inference, Api.safety, Api.memory],
|
[Api.agents, Api.inference, Api.safety, Api.memory],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
|
@ -87,6 +83,6 @@ async def agents_stack(request, inference_model, safety_model):
|
||||||
)
|
)
|
||||||
for model in inference_models
|
for model in inference_models
|
||||||
],
|
],
|
||||||
shields=[shield_input],
|
shields=[safety_shield],
|
||||||
)
|
)
|
||||||
return impls[Api.agents], impls[Api.memory]
|
return test_stack
|
||||||
|
|
|
@ -1,148 +0,0 @@
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# the root directory of this source tree.
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
|
||||||
from llama_stack.providers.datatypes import * # noqa: F403
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
|
||||||
|
|
||||||
# How to run this test:
|
|
||||||
#
|
|
||||||
# 1. Ensure you have a conda environment with the right dependencies installed.
|
|
||||||
# This includes `pytest` and `pytest-asyncio`.
|
|
||||||
#
|
|
||||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
|
||||||
#
|
|
||||||
# 3. Run:
|
|
||||||
#
|
|
||||||
# ```bash
|
|
||||||
# PROVIDER_ID=<your_provider> \
|
|
||||||
# PROVIDER_CONFIG=provider_config.yaml \
|
|
||||||
# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
|
|
||||||
# --tb=short --disable-warnings
|
|
||||||
# ```
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
|
||||||
async def agents_settings():
|
|
||||||
impls = await resolve_impls_for_test(
|
|
||||||
Api.agents, deps=[Api.inference, Api.memory, Api.safety]
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"impl": impls[Api.agents],
|
|
||||||
"memory_impl": impls[Api.memory],
|
|
||||||
"common_params": {
|
|
||||||
"model": "Llama3.1-8B-Instruct",
|
|
||||||
"instructions": "You are a helpful assistant.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_messages():
|
|
||||||
return [
|
|
||||||
UserMessage(content="What's the weather like today?"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_agents_and_sessions(agents_settings, sample_messages):
|
|
||||||
agents_impl = agents_settings["impl"]
|
|
||||||
# First, create an agent
|
|
||||||
agent_config = AgentConfig(
|
|
||||||
model=agents_settings["common_params"]["model"],
|
|
||||||
instructions=agents_settings["common_params"]["instructions"],
|
|
||||||
enable_session_persistence=True,
|
|
||||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
|
||||||
input_shields=[],
|
|
||||||
output_shields=[],
|
|
||||||
tools=[],
|
|
||||||
max_infer_iters=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
create_response = await agents_impl.create_agent(agent_config)
|
|
||||||
agent_id = create_response.agent_id
|
|
||||||
|
|
||||||
# Create a session
|
|
||||||
session_create_response = await agents_impl.create_agent_session(
|
|
||||||
agent_id, "Test Session"
|
|
||||||
)
|
|
||||||
session_id = session_create_response.session_id
|
|
||||||
persistence_store = await kvstore_impl(agents_settings["persistence"])
|
|
||||||
|
|
||||||
await agents_impl.delete_agents_session(agent_id, session_id)
|
|
||||||
session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
|
|
||||||
|
|
||||||
await agents_impl.delete_agents(agent_id)
|
|
||||||
agent_response = await persistence_store.get(f"agent:{agent_id}")
|
|
||||||
|
|
||||||
assert session_response is None
|
|
||||||
assert agent_response is None
|
|
||||||
|
|
||||||
|
|
||||||
async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
|
|
||||||
agents_impl = agents_settings["impl"]
|
|
||||||
|
|
||||||
# First, create an agent
|
|
||||||
agent_config = AgentConfig(
|
|
||||||
model=agents_settings["common_params"]["model"],
|
|
||||||
instructions=agents_settings["common_params"]["instructions"],
|
|
||||||
enable_session_persistence=True,
|
|
||||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
|
||||||
input_shields=[],
|
|
||||||
output_shields=[],
|
|
||||||
tools=[],
|
|
||||||
max_infer_iters=5,
|
|
||||||
)
|
|
||||||
|
|
||||||
create_response = await agents_impl.create_agent(agent_config)
|
|
||||||
agent_id = create_response.agent_id
|
|
||||||
|
|
||||||
# Create a session
|
|
||||||
session_create_response = await agents_impl.create_agent_session(
|
|
||||||
agent_id, "Test Session"
|
|
||||||
)
|
|
||||||
session_id = session_create_response.session_id
|
|
||||||
|
|
||||||
# Create and execute a turn
|
|
||||||
turn_request = dict(
|
|
||||||
agent_id=agent_id,
|
|
||||||
session_id=session_id,
|
|
||||||
messages=sample_messages,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
turn_response = [
|
|
||||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
|
||||||
]
|
|
||||||
|
|
||||||
final_event = turn_response[-1].event.payload
|
|
||||||
turn_id = final_event.turn.turn_id
|
|
||||||
persistence_store = await kvstore_impl(SqliteKVStoreConfig())
|
|
||||||
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
|
||||||
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
|
|
||||||
|
|
||||||
assert isinstance(response, Turn)
|
|
||||||
assert response == final_event.turn
|
|
||||||
assert turn == final_event.turn
|
|
||||||
|
|
||||||
steps = final_event.turn.steps
|
|
||||||
step_id = steps[0].step_id
|
|
||||||
step_response = await agents_impl.get_agents_step(
|
|
||||||
agent_id, session_id, turn_id, step_id
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(step_response.step, Step)
|
|
||||||
assert step_response.step == steps[0]
|
|
|
@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
# -m "meta_reference"
|
# -m "meta_reference"
|
||||||
|
|
||||||
from .fixtures import pick_inference_model
|
from .fixtures import pick_inference_model
|
||||||
|
from .utils import create_agent_session
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -67,31 +68,19 @@ def query_attachment_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def create_agent_session(agents_impl, agent_config):
|
|
||||||
create_response = await agents_impl.create_agent(agent_config)
|
|
||||||
agent_id = create_response.agent_id
|
|
||||||
|
|
||||||
# Create a session
|
|
||||||
session_create_response = await agents_impl.create_agent_session(
|
|
||||||
agent_id, "Test Session"
|
|
||||||
)
|
|
||||||
session_id = session_create_response.session_id
|
|
||||||
return agent_id, session_id
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgents:
|
class TestAgents:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_turns_with_safety(
|
async def test_agent_turns_with_safety(
|
||||||
self, safety_model, agents_stack, common_params
|
self, safety_shield, agents_stack, common_params
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
agent_id, session_id = await create_agent_session(
|
agent_id, session_id = await create_agent_session(
|
||||||
agents_impl,
|
agents_impl,
|
||||||
AgentConfig(
|
AgentConfig(
|
||||||
**{
|
**{
|
||||||
**common_params,
|
**common_params,
|
||||||
"input_shields": [safety_model],
|
"input_shields": [safety_shield.shield_id],
|
||||||
"output_shields": [safety_model],
|
"output_shields": [safety_shield.shield_id],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -127,7 +116,7 @@ class TestAgents:
|
||||||
async def test_create_agent_turn(
|
async def test_create_agent_turn(
|
||||||
self, agents_stack, sample_messages, common_params
|
self, agents_stack, sample_messages, common_params
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
|
||||||
agent_id, session_id = await create_agent_session(
|
agent_id, session_id = await create_agent_session(
|
||||||
agents_impl, AgentConfig(**common_params)
|
agents_impl, AgentConfig(**common_params)
|
||||||
|
@ -158,7 +147,7 @@ class TestAgents:
|
||||||
query_attachment_messages,
|
query_attachment_messages,
|
||||||
common_params,
|
common_params,
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
urls = [
|
urls = [
|
||||||
"memory_optimizations.rst",
|
"memory_optimizations.rst",
|
||||||
"chat.rst",
|
"chat.rst",
|
||||||
|
@ -226,7 +215,7 @@ class TestAgents:
|
||||||
async def test_create_agent_turn_with_brave_search(
|
async def test_create_agent_turn_with_brave_search(
|
||||||
self, agents_stack, search_query_messages, common_params
|
self, agents_stack, search_query_messages, common_params
|
||||||
):
|
):
|
||||||
agents_impl, _ = agents_stack
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
|
||||||
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
if "BRAVE_SEARCH_API_KEY" not in os.environ:
|
||||||
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")
|
||||||
|
|
122
llama_stack/providers/tests/agents/test_persistence.py
Normal file
122
llama_stack/providers/tests/agents/test_persistence.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
|
from llama_stack.providers.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||||
|
from .fixtures import pick_inference_model
|
||||||
|
|
||||||
|
from .utils import create_agent_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_messages():
|
||||||
|
return [
|
||||||
|
UserMessage(content="What's the weather like today?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def common_params(inference_model):
|
||||||
|
inference_model = pick_inference_model(inference_model)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
model=inference_model,
|
||||||
|
instructions="You are a helpful assistant.",
|
||||||
|
enable_session_persistence=True,
|
||||||
|
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||||
|
input_shields=[],
|
||||||
|
output_shields=[],
|
||||||
|
tools=[],
|
||||||
|
max_infer_iters=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentPersistence:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_agents_and_sessions(self, agents_stack, common_params):
|
||||||
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
agent_id, session_id = await create_agent_session(
|
||||||
|
agents_impl,
|
||||||
|
AgentConfig(
|
||||||
|
**{
|
||||||
|
**common_params,
|
||||||
|
"input_shields": [],
|
||||||
|
"output_shields": [],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
run_config = agents_stack.run_config
|
||||||
|
provider_config = run_config.providers["agents"][0].config
|
||||||
|
persistence_store = await kvstore_impl(
|
||||||
|
SqliteKVStoreConfig(**provider_config["persistence_store"])
|
||||||
|
)
|
||||||
|
|
||||||
|
await agents_impl.delete_agents_session(agent_id, session_id)
|
||||||
|
session_response = await persistence_store.get(
|
||||||
|
f"session:{agent_id}:{session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await agents_impl.delete_agents(agent_id)
|
||||||
|
agent_response = await persistence_store.get(f"agent:{agent_id}")
|
||||||
|
|
||||||
|
assert session_response is None
|
||||||
|
assert agent_response is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent_turns_and_steps(
|
||||||
|
self, agents_stack, sample_messages, common_params
|
||||||
|
):
|
||||||
|
agents_impl = agents_stack.impls[Api.agents]
|
||||||
|
|
||||||
|
agent_id, session_id = await create_agent_session(
|
||||||
|
agents_impl,
|
||||||
|
AgentConfig(
|
||||||
|
**{
|
||||||
|
**common_params,
|
||||||
|
"input_shields": [],
|
||||||
|
"output_shields": [],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and execute a turn
|
||||||
|
turn_request = dict(
|
||||||
|
agent_id=agent_id,
|
||||||
|
session_id=session_id,
|
||||||
|
messages=sample_messages,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
turn_response = [
|
||||||
|
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||||
|
]
|
||||||
|
|
||||||
|
final_event = turn_response[-1].event.payload
|
||||||
|
turn_id = final_event.turn.turn_id
|
||||||
|
|
||||||
|
provider_config = agents_stack.run_config.providers["agents"][0].config
|
||||||
|
persistence_store = await kvstore_impl(
|
||||||
|
SqliteKVStoreConfig(**provider_config["persistence_store"])
|
||||||
|
)
|
||||||
|
turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
|
||||||
|
response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
|
||||||
|
|
||||||
|
assert isinstance(response, Turn)
|
||||||
|
assert response == final_event.turn
|
||||||
|
assert turn == final_event.turn.model_dump_json()
|
||||||
|
|
||||||
|
steps = final_event.turn.steps
|
||||||
|
step_id = steps[0].step_id
|
||||||
|
step_response = await agents_impl.get_agents_step(
|
||||||
|
agent_id, session_id, turn_id, step_id
|
||||||
|
)
|
||||||
|
|
||||||
|
assert step_response.step == steps[0]
|
17
llama_stack/providers/tests/agents/utils.py
Normal file
17
llama_stack/providers/tests/agents/utils.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
async def create_agent_session(agents_impl, agent_config):
|
||||||
|
create_response = await agents_impl.create_agent(agent_config)
|
||||||
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
session_create_response = await agents_impl.create_agent_session(
|
||||||
|
agent_id, "Test Session"
|
||||||
|
)
|
||||||
|
session_id = session_create_response.session_id
|
||||||
|
return agent_id, session_id
|
|
@ -35,8 +35,8 @@ def remote_stack_fixture() -> ProviderFixture:
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
providers=[
|
providers=[
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="remote",
|
provider_id="test::remote",
|
||||||
provider_type="remote",
|
provider_type="test::remote",
|
||||||
config=config.model_dump(),
|
config=config.model_dump(),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
|
|
@ -9,7 +9,7 @@ import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,10 +52,10 @@ async def datasetio_stack(request):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
|
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.datasetio],
|
[Api.datasetio],
|
||||||
{"datasetio": fixture.providers},
|
{"datasetio": fixture.providers},
|
||||||
fixture.provider_data,
|
fixture.provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls[Api.datasetio], impls[Api.datasets]
|
return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets]
|
||||||
|
|
|
@ -9,7 +9,7 @@ import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,10 +46,10 @@ async def eval_stack(request):
|
||||||
if fixture.provider_data:
|
if fixture.provider_data:
|
||||||
provider_data.update(fixture.provider_data)
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
|
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls
|
return test_stack.impls
|
||||||
|
|
|
@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||||
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
||||||
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
@ -182,15 +182,11 @@ INFERENCE_FIXTURES = [
|
||||||
async def inference_stack(request, inference_model):
|
async def inference_stack(request, inference_model):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.inference],
|
[Api.inference],
|
||||||
{"inference": inference_fixture.providers},
|
{"inference": inference_fixture.providers},
|
||||||
inference_fixture.provider_data,
|
inference_fixture.provider_data,
|
||||||
models=[
|
models=[ModelInput(model_id=inference_model)],
|
||||||
ModelInput(
|
|
||||||
model_id=inference_model,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return (impls[Api.inference], impls[Api.models])
|
return test_stack.impls[Api.inference], test_stack.impls[Api.models]
|
||||||
|
|
|
@ -147,9 +147,9 @@ class TestInference:
|
||||||
|
|
||||||
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
||||||
response = await inference_impl.completion(
|
response = await inference_impl.completion(
|
||||||
|
model_id=inference_model,
|
||||||
content=user_input,
|
content=user_input,
|
||||||
stream=False,
|
stream=False,
|
||||||
model=inference_model,
|
|
||||||
sampling_params=SamplingParams(
|
sampling_params=SamplingParams(
|
||||||
max_tokens=50,
|
max_tokens=50,
|
||||||
),
|
),
|
||||||
|
|
|
@ -55,7 +55,7 @@ class TestVisionModelInference:
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await inference_impl.chat_completion(
|
response = await inference_impl.chat_completion(
|
||||||
model=inference_model,
|
model_id=inference_model,
|
||||||
messages=[
|
messages=[
|
||||||
UserMessage(content="You are a helpful assistant."),
|
UserMessage(content="You are a helpful assistant."),
|
||||||
UserMessage(content=[image, "Describe this image in two sentences."]),
|
UserMessage(content=[image, "Describe this image in two sentences."]),
|
||||||
|
@ -102,7 +102,7 @@ class TestVisionModelInference:
|
||||||
response = [
|
response = [
|
||||||
r
|
r
|
||||||
async for r in await inference_impl.chat_completion(
|
async for r in await inference_impl.chat_completion(
|
||||||
model=inference_model,
|
model_id=inference_model,
|
||||||
messages=[
|
messages=[
|
||||||
UserMessage(content="You are a helpful assistant."),
|
UserMessage(content="You are a helpful assistant."),
|
||||||
UserMessage(
|
UserMessage(
|
||||||
|
|
|
@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConf
|
||||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
@ -101,10 +101,10 @@ async def memory_stack(request):
|
||||||
fixture_name = request.param
|
fixture_name = request.param
|
||||||
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
fixture = request.getfixturevalue(f"memory_{fixture_name}")
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.memory],
|
[Api.memory],
|
||||||
{"memory": fixture.providers},
|
{"memory": fixture.providers},
|
||||||
fixture.provider_data,
|
fixture.provider_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls[Api.memory], impls[Api.memory_banks]
|
return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
|
||||||
|
|
|
@ -5,33 +5,36 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from llama_stack.distribution.build import print_pip_install_help
|
from llama_stack.distribution.build import print_pip_install_help
|
||||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||||
from llama_stack.distribution.distribution import get_provider_registry
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
|
from llama_stack.distribution.resolver import resolve_remote_stack_impls
|
||||||
from llama_stack.distribution.stack import construct_stack
|
from llama_stack.distribution.stack import construct_stack
|
||||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test_v2(
|
class TestStack(BaseModel):
|
||||||
|
impls: Dict[Api, Any]
|
||||||
|
run_config: StackRunConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def construct_stack_for_test(
|
||||||
apis: List[Api],
|
apis: List[Api],
|
||||||
providers: Dict[str, List[Provider]],
|
providers: Dict[str, List[Provider]],
|
||||||
provider_data: Optional[Dict[str, Any]] = None,
|
provider_data: Optional[Dict[str, Any]] = None,
|
||||||
models: Optional[List[Model]] = None,
|
models: Optional[List[ModelInput]] = None,
|
||||||
shields: Optional[List[Shield]] = None,
|
shields: Optional[List[ShieldInput]] = None,
|
||||||
memory_banks: Optional[List[MemoryBank]] = None,
|
memory_banks: Optional[List[MemoryBankInput]] = None,
|
||||||
datasets: Optional[List[Dataset]] = None,
|
datasets: Optional[List[DatasetInput]] = None,
|
||||||
scoring_fns: Optional[List[ScoringFn]] = None,
|
scoring_fns: Optional[List[ScoringFnInput]] = None,
|
||||||
eval_tasks: Optional[List[EvalTask]] = None,
|
eval_tasks: Optional[List[EvalTaskInput]] = None,
|
||||||
):
|
) -> TestStack:
|
||||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
run_config = dict(
|
run_config = dict(
|
||||||
built_at=datetime.now(),
|
built_at=datetime.now(),
|
||||||
|
@ -48,7 +51,18 @@ async def resolve_impls_for_test_v2(
|
||||||
)
|
)
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||||
try:
|
try:
|
||||||
impls = await construct_stack(run_config)
|
remote_config = remote_provider_config(run_config)
|
||||||
|
if not remote_config:
|
||||||
|
# TODO: add to provider registry by creating interesting mocks or fakes
|
||||||
|
impls = await construct_stack(run_config, get_provider_registry())
|
||||||
|
else:
|
||||||
|
# we don't register resources for a remote stack as part of the fixture setup
|
||||||
|
# because the stack is already "up". if a test needs to register resources, it
|
||||||
|
# can do so manually always.
|
||||||
|
|
||||||
|
impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
|
||||||
|
|
||||||
|
test_stack = TestStack(impls=impls, run_config=run_config)
|
||||||
except ModuleNotFoundError as e:
|
except ModuleNotFoundError as e:
|
||||||
print_pip_install_help(providers)
|
print_pip_install_help(providers)
|
||||||
raise e
|
raise e
|
||||||
|
@ -58,91 +72,22 @@ async def resolve_impls_for_test_v2(
|
||||||
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls
|
return test_stack
|
||||||
|
|
||||||
|
|
||||||
async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
|
def remote_provider_config(
|
||||||
if "PROVIDER_CONFIG" not in os.environ:
|
run_config: StackRunConfig,
|
||||||
raise ValueError(
|
) -> Optional[RemoteProviderConfig]:
|
||||||
"You must set PROVIDER_CONFIG to a YAML file containing provider config"
|
remote_config = None
|
||||||
)
|
has_non_remote = False
|
||||||
|
for api_providers in run_config.providers.values():
|
||||||
|
for provider in api_providers:
|
||||||
|
if provider.provider_type == "test::remote":
|
||||||
|
remote_config = RemoteProviderConfig(**provider.config)
|
||||||
|
else:
|
||||||
|
has_non_remote = True
|
||||||
|
|
||||||
with open(os.environ["PROVIDER_CONFIG"], "r") as f:
|
if remote_config:
|
||||||
config_dict = yaml.safe_load(f)
|
assert not has_non_remote, "Remote stack cannot have non-remote providers"
|
||||||
|
|
||||||
providers = read_providers(api, config_dict)
|
return remote_config
|
||||||
|
|
||||||
chosen = choose_providers(providers, api, deps)
|
|
||||||
run_config = dict(
|
|
||||||
built_at=datetime.now(),
|
|
||||||
image_name="test-fixture",
|
|
||||||
apis=[api] + (deps or []),
|
|
||||||
providers=chosen,
|
|
||||||
)
|
|
||||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
|
||||||
try:
|
|
||||||
impls = await resolve_impls(run_config, get_provider_registry())
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
print_pip_install_help(providers)
|
|
||||||
raise e
|
|
||||||
|
|
||||||
if "provider_data" in config_dict:
|
|
||||||
provider_id = chosen[api.value][0].provider_id
|
|
||||||
provider_data = config_dict["provider_data"].get(provider_id, {})
|
|
||||||
if provider_data:
|
|
||||||
set_request_provider_data(
|
|
||||||
{"X-LlamaStack-ProviderData": json.dumps(provider_data)}
|
|
||||||
)
|
|
||||||
|
|
||||||
return impls
|
|
||||||
|
|
||||||
|
|
||||||
def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
if "providers" not in config_dict:
|
|
||||||
raise ValueError("Config file should contain a `providers` key")
|
|
||||||
|
|
||||||
providers = config_dict["providers"]
|
|
||||||
if isinstance(providers, dict):
|
|
||||||
return providers
|
|
||||||
elif isinstance(providers, list):
|
|
||||||
return {
|
|
||||||
api.value: providers,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Config file should contain a list of providers or dict(api to providers)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def choose_providers(
|
|
||||||
providers: Dict[str, Any], api: Api, deps: List[Api] = None
|
|
||||||
) -> Dict[str, Provider]:
|
|
||||||
chosen = {}
|
|
||||||
if api.value not in providers:
|
|
||||||
raise ValueError(f"No providers found for `{api}`?")
|
|
||||||
chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")]
|
|
||||||
|
|
||||||
for dep in deps or []:
|
|
||||||
if dep.value not in providers:
|
|
||||||
raise ValueError(f"No providers specified for `{dep}` in config?")
|
|
||||||
chosen[dep.value] = [Provider(**x) for x in providers[dep.value]]
|
|
||||||
|
|
||||||
return chosen
|
|
||||||
|
|
||||||
|
|
||||||
def pick_provider(api: Api, providers: List[Any], key: str) -> Provider:
|
|
||||||
providers_by_id = {x["provider_id"]: x for x in providers}
|
|
||||||
if len(providers_by_id) == 0:
|
|
||||||
raise ValueError(f"No providers found for `{api}` in config file")
|
|
||||||
|
|
||||||
if key in os.environ:
|
|
||||||
provider_id = os.environ[key]
|
|
||||||
if provider_id not in providers_by_id:
|
|
||||||
raise ValueError(f"Provider ID {provider_id} not found in config file")
|
|
||||||
provider = providers_by_id[provider_id]
|
|
||||||
else:
|
|
||||||
provider = list(providers_by_id.values())[0]
|
|
||||||
provider_id = provider["provider_id"]
|
|
||||||
print(f"No provider ID specified, picking first `{provider_id}`")
|
|
||||||
|
|
||||||
return Provider(**provider)
|
|
||||||
|
|
|
@ -66,14 +66,14 @@ def pytest_configure(config):
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--safety-model",
|
"--safety-shield",
|
||||||
action="store",
|
action="store",
|
||||||
default=None,
|
default=None,
|
||||||
help="Specify the safety model to use for testing",
|
help="Specify the safety shield to use for testing",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SAFETY_MODEL_PARAMS = [
|
SAFETY_SHIELD_PARAMS = [
|
||||||
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
|
pytest.param("Llama-Guard-3-1B", marks=pytest.mark.guard_1b, id="guard_1b"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -83,13 +83,13 @@ def pytest_generate_tests(metafunc):
|
||||||
# But a user can also pass in a custom combination via the CLI by doing
|
# But a user can also pass in a custom combination via the CLI by doing
|
||||||
# `--providers inference=together,safety=meta_reference`
|
# `--providers inference=together,safety=meta_reference`
|
||||||
|
|
||||||
if "safety_model" in metafunc.fixturenames:
|
if "safety_shield" in metafunc.fixturenames:
|
||||||
model = metafunc.config.getoption("--safety-model")
|
shield_id = metafunc.config.getoption("--safety-shield")
|
||||||
if model:
|
if shield_id:
|
||||||
params = [pytest.param(model, id="")]
|
params = [pytest.param(shield_id, id="")]
|
||||||
else:
|
else:
|
||||||
params = SAFETY_MODEL_PARAMS
|
params = SAFETY_SHIELD_PARAMS
|
||||||
for fixture in ["inference_model", "safety_model"]:
|
for fixture in ["inference_model", "safety_shield"]:
|
||||||
metafunc.parametrize(
|
metafunc.parametrize(
|
||||||
fixture,
|
fixture,
|
||||||
params,
|
params,
|
||||||
|
|
|
@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
|
||||||
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
|
||||||
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
|
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
from ..env import get_env_or_fail
|
from ..env import get_env_or_fail
|
||||||
|
@ -27,19 +27,38 @@ def safety_remote() -> ProviderFixture:
|
||||||
return remote_stack_fixture()
|
return remote_stack_fixture()
|
||||||
|
|
||||||
|
|
||||||
|
def safety_model_from_shield(shield_id):
    """Map a shield ID to a safety model ID.

    Returns `None` when `shield_id` is exactly "Bedrock", "CodeScanner" or
    "CodeShield" (case-sensitive); any other value is returned unchanged.
    """
    # NOTE(review): presumably these shields are not backed by a model that
    # needs registering -- confirm against the callers.
    shields_without_model = ("Bedrock", "CodeScanner", "CodeShield")
    return None if shield_id in shields_without_model else shield_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def safety_model(request):
|
def safety_shield(request):
|
||||||
if hasattr(request, "param"):
|
if hasattr(request, "param"):
|
||||||
return request.param
|
shield_id = request.param
|
||||||
return request.config.getoption("--safety-model", None)
|
else:
|
||||||
|
shield_id = request.config.getoption("--safety-shield", None)
|
||||||
|
|
||||||
|
if shield_id == "bedrock":
|
||||||
|
shield_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
|
||||||
|
params = {"guardrailVersion": get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")}
|
||||||
|
else:
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
return ShieldInput(
|
||||||
|
shield_id=shield_id,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def safety_llama_guard(safety_model) -> ProviderFixture:
|
def safety_llama_guard() -> ProviderFixture:
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
providers=[
|
providers=[
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="inline::llama-guard",
|
provider_id="llama-guard",
|
||||||
provider_type="inline::llama-guard",
|
provider_type="inline::llama-guard",
|
||||||
config=LlamaGuardConfig().model_dump(),
|
config=LlamaGuardConfig().model_dump(),
|
||||||
)
|
)
|
||||||
|
@ -55,7 +74,7 @@ def safety_prompt_guard() -> ProviderFixture:
|
||||||
return ProviderFixture(
|
return ProviderFixture(
|
||||||
providers=[
|
providers=[
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="inline::prompt-guard",
|
provider_id="prompt-guard",
|
||||||
provider_type="inline::prompt-guard",
|
provider_type="inline::prompt-guard",
|
||||||
config=PromptGuardConfig().model_dump(),
|
config=PromptGuardConfig().model_dump(),
|
||||||
)
|
)
|
||||||
|
@ -80,50 +99,25 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def safety_stack(inference_model, safety_model, request):
|
async def safety_stack(inference_model, safety_shield, request):
|
||||||
# We need an inference + safety fixture to test safety
|
# We need an inference + safety fixture to test safety
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
inference_fixture = request.getfixturevalue(
|
|
||||||
f"inference_{fixture_dict['inference']}"
|
|
||||||
)
|
|
||||||
safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")
|
|
||||||
|
|
||||||
providers = {
|
providers = {}
|
||||||
"inference": inference_fixture.providers,
|
|
||||||
"safety": safety_fixture.providers,
|
|
||||||
}
|
|
||||||
provider_data = {}
|
provider_data = {}
|
||||||
if inference_fixture.provider_data:
|
for key in ["inference", "safety"]:
|
||||||
provider_data.update(inference_fixture.provider_data)
|
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
|
||||||
if safety_fixture.provider_data:
|
providers[key] = fixture.providers
|
||||||
provider_data.update(safety_fixture.provider_data)
|
if fixture.provider_data:
|
||||||
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
shield_provider_type = safety_fixture.providers[0].provider_type
|
test_stack = await construct_stack_for_test(
|
||||||
shield_input = get_shield_to_register(shield_provider_type, safety_model)
|
|
||||||
|
|
||||||
print(f"inference_model: {inference_model}")
|
|
||||||
print(f"shield_input = {shield_input}")
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
|
||||||
[Api.safety, Api.shields, Api.inference],
|
[Api.safety, Api.shields, Api.inference],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
models=[ModelInput(model_id=inference_model)],
|
models=[ModelInput(model_id=inference_model)],
|
||||||
shields=[shield_input],
|
shields=[safety_shield],
|
||||||
)
|
)
|
||||||
|
|
||||||
shield = await impls[Api.shields].get_shield(shield_input.shield_id)
|
shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
|
||||||
return impls[Api.safety], impls[Api.shields], shield
|
return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield
|
||||||
|
|
||||||
|
|
||||||
def get_shield_to_register(provider_type: str, safety_model: str) -> ShieldInput:
    """Build the `ShieldInput` to register for a given safety provider type.

    For "remote::bedrock" the shield ID and the "guardrailVersion" param are
    read via `get_env_or_fail` from BEDROCK_GUARDRAIL_IDENTIFIER and
    BEDROCK_GUARDRAIL_VERSION. For every other provider type the shield ID
    is `safety_model` and the params are empty.
    """
    if provider_type != "remote::bedrock":
        return ShieldInput(shield_id=safety_model, params={})

    # Read the identifier before the version, matching the original order.
    guardrail_id = get_env_or_fail("BEDROCK_GUARDRAIL_IDENTIFIER")
    guardrail_version = get_env_or_fail("BEDROCK_GUARDRAIL_VERSION")
    return ShieldInput(
        shield_id=guardrail_id,
        params={"guardrailVersion": guardrail_version},
    )
|
|
||||||
|
|
|
@ -18,13 +18,6 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
class TestSafety:
|
class TestSafety:
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_new_shield(self, safety_stack):
|
|
||||||
_, shields_impl, shield = safety_stack
|
|
||||||
assert shield is not None
|
|
||||||
assert shield.provider_resource_id == shield.identifier
|
|
||||||
assert shield.provider_id is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_shield_list(self, safety_stack):
|
async def test_shield_list(self, safety_stack):
|
||||||
_, shields_impl, _ = safety_stack
|
_, shields_impl, _ = safety_stack
|
||||||
|
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import construct_stack_for_test
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model):
|
||||||
if fixture.provider_data:
|
if fixture.provider_data:
|
||||||
provider_data.update(fixture.provider_data)
|
provider_data.update(fixture.provider_data)
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
test_stack = await construct_stack_for_test(
|
||||||
[Api.scoring, Api.datasetio, Api.inference],
|
[Api.scoring, Api.datasetio, Api.inference],
|
||||||
providers,
|
providers,
|
||||||
provider_data,
|
provider_data,
|
||||||
|
@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
return impls
|
return test_stack.impls
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue