forked from phoenix-oss/llama-stack-mirror
Kill "remote" providers and fix testing with a remote stack properly (#435)
# What does this PR do? This PR kills the notion of "pure passthrough" remote providers. You cannot specify a single provider you must specify a whole distribution (stack) as remote. This PR also significantly fixes / upgrades testing infrastructure so you can now test against a remotely hosted stack server by just doing ```bash pytest -s -v -m remote test_agents.py \ --inference-model=Llama3.1-8B-Instruct --safety-shield=Llama-Guard-3-1B \ --env REMOTE_STACK_URL=http://localhost:5001 ``` Also fixed `test_agents_persistence.py` (which was broken) and killed some deprecated testing functions. ## Test Plan All the tests.
This commit is contained in:
parent
59a65e34d3
commit
12947ac19e
28 changed files with 406 additions and 519 deletions
|
@ -20,21 +20,17 @@ from llama_stack.providers.datatypes import RemoteProviderConfig
|
|||
_CLIENT_CLASSES = {}
|
||||
|
||||
|
||||
async def get_client_impl(
|
||||
protocol, additional_protocol, config: RemoteProviderConfig, _deps: Any
|
||||
):
|
||||
client_class = create_api_client_class(protocol, additional_protocol)
|
||||
async def get_client_impl(protocol, config: RemoteProviderConfig, _deps: Any):
|
||||
client_class = create_api_client_class(protocol)
|
||||
impl = client_class(config.url)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
def create_api_client_class(protocol, additional_protocol) -> Type:
|
||||
def create_api_client_class(protocol) -> Type:
|
||||
if protocol in _CLIENT_CLASSES:
|
||||
return _CLIENT_CLASSES[protocol]
|
||||
|
||||
protocols = [protocol, additional_protocol] if additional_protocol else [protocol]
|
||||
|
||||
class APIClient:
|
||||
def __init__(self, base_url: str):
|
||||
print(f"({protocol.__name__}) Connecting to {base_url}")
|
||||
|
@ -42,11 +38,10 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
|||
self.routes = {}
|
||||
|
||||
# Store routes for this protocol
|
||||
for p in protocols:
|
||||
for name, method in inspect.getmembers(p):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
sig = inspect.signature(method)
|
||||
self.routes[name] = (method.__webmethod__, sig)
|
||||
for name, method in inspect.getmembers(protocol):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
sig = inspect.signature(method)
|
||||
self.routes[name] = (method.__webmethod__, sig)
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
@ -160,17 +155,16 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
|||
return ret
|
||||
|
||||
# Add protocol methods to the wrapper
|
||||
for p in protocols:
|
||||
for name, method in inspect.getmembers(p):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
for name, method in inspect.getmembers(protocol):
|
||||
if hasattr(method, "__webmethod__"):
|
||||
|
||||
async def method_impl(self, *args, method_name=name, **kwargs):
|
||||
return await self.__acall__(method_name, *args, **kwargs)
|
||||
async def method_impl(self, *args, method_name=name, **kwargs):
|
||||
return await self.__acall__(method_name, *args, **kwargs)
|
||||
|
||||
method_impl.__name__ = name
|
||||
method_impl.__qualname__ = f"APIClient.{name}"
|
||||
method_impl.__signature__ = inspect.signature(method)
|
||||
setattr(APIClient, name, method_impl)
|
||||
method_impl.__name__ = name
|
||||
method_impl.__qualname__ = f"APIClient.{name}"
|
||||
method_impl.__signature__ = inspect.signature(method)
|
||||
setattr(APIClient, name, method_impl)
|
||||
|
||||
# Name the class after the protocol
|
||||
APIClient.__name__ = f"{protocol.__name__}Client"
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Dict, List
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
|
@ -62,9 +62,6 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
|||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {
|
||||
"remote": remote_provider_spec(api),
|
||||
**{a.provider_type: a for a in module.available_providers()},
|
||||
}
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
|
||||
return ret
|
||||
|
|
|
@ -28,6 +28,7 @@ from llama_stack.apis.scoring import Scoring
|
|||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.telemetry import Telemetry
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
@ -59,12 +60,16 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
|
||||
def additional_protocols_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models),
|
||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
|
||||
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
|
||||
Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||
Api.scoring: (
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
ScoringFunctions,
|
||||
Api.scoring_functions,
|
||||
),
|
||||
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
|
||||
}
|
||||
|
||||
|
||||
|
@ -73,10 +78,13 @@ class ProviderWithSpec(Provider):
|
|||
spec: ProviderSpec
|
||||
|
||||
|
||||
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
||||
|
||||
|
||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||
async def resolve_impls(
|
||||
run_config: StackRunConfig,
|
||||
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
|
||||
provider_registry: ProviderRegistry,
|
||||
dist_registry: DistributionRegistry,
|
||||
) -> Dict[Api, Any]:
|
||||
"""
|
||||
|
@ -273,17 +281,8 @@ async def instantiate_provider(
|
|||
config_type = instantiate_class_type(provider_spec.config_class)
|
||||
config = config_type(**provider.config)
|
||||
|
||||
if provider_spec.adapter:
|
||||
method = "get_adapter_impl"
|
||||
args = [config, deps]
|
||||
else:
|
||||
method = "get_client_impl"
|
||||
protocol = protocols[provider_spec.api]
|
||||
if provider_spec.api in additional_protocols:
|
||||
_, additional_protocol = additional_protocols[provider_spec.api]
|
||||
else:
|
||||
additional_protocol = None
|
||||
args = [protocol, additional_protocol, config, deps]
|
||||
method = "get_adapter_impl"
|
||||
args = [config, deps]
|
||||
|
||||
elif isinstance(provider_spec, AutoRoutedProviderSpec):
|
||||
method = "get_auto_router_impl"
|
||||
|
@ -313,7 +312,7 @@ async def instantiate_provider(
|
|||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||
and provider_spec.api in additional_protocols
|
||||
):
|
||||
additional_api, _ = additional_protocols[provider_spec.api]
|
||||
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||
check_protocol_compliance(impl, additional_api)
|
||||
|
||||
return impl
|
||||
|
@ -359,3 +358,29 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
raise ValueError(
|
||||
f"Provider `{obj.__provider_id__} ({obj.__provider_spec__.api})` does not implement the following methods:\n{missing_methods}"
|
||||
)
|
||||
|
||||
|
||||
async def resolve_remote_stack_impls(
|
||||
config: RemoteProviderConfig,
|
||||
apis: List[str],
|
||||
) -> Dict[Api, Any]:
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
impls = {}
|
||||
for api_str in apis:
|
||||
api = Api(api_str)
|
||||
impls[api] = await get_client_impl(
|
||||
protocols[api],
|
||||
config,
|
||||
{},
|
||||
)
|
||||
if api in additional_protocols:
|
||||
_, additional_protocol, additional_api = additional_protocols[api]
|
||||
impls[additional_api] = await get_client_impl(
|
||||
additional_protocol,
|
||||
config,
|
||||
{},
|
||||
)
|
||||
|
||||
return impls
|
||||
|
|
|
@ -33,28 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
api = get_impl_api(p)
|
||||
|
||||
if obj.provider_id == "remote":
|
||||
# TODO: this is broken right now because we use the generic
|
||||
# { identifier, provider_id, provider_resource_id } tuple here
|
||||
# but the APIs expect things like ModelInput, ShieldInput, etc.
|
||||
|
||||
# if this is just a passthrough, we want to let the remote
|
||||
# end actually do the registration with the correct provider
|
||||
obj = obj.model_copy(deep=True)
|
||||
obj.provider_id = ""
|
||||
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||
|
||||
if api == Api.inference:
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
await p.register_shield(obj)
|
||||
return await p.register_shield(obj)
|
||||
elif api == Api.memory:
|
||||
await p.register_memory_bank(obj)
|
||||
return await p.register_memory_bank(obj)
|
||||
elif api == Api.datasetio:
|
||||
await p.register_dataset(obj)
|
||||
return await p.register_dataset(obj)
|
||||
elif api == Api.scoring:
|
||||
await p.register_scoring_function(obj)
|
||||
return await p.register_scoring_function(obj)
|
||||
elif api == Api.eval:
|
||||
await p.register_eval_task(obj)
|
||||
return await p.register_eval_task(obj)
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
@ -82,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
if provider_id == "remote":
|
||||
# if this is just a passthrough, we got the *WithProvider object
|
||||
# so we should just override the provider in-place
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
# Register all objects from providers
|
||||
|
@ -100,18 +87,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
p.model_store = self
|
||||
elif api == Api.safety:
|
||||
p.shield_store = self
|
||||
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
|
||||
elif api == Api.scoring:
|
||||
p.scoring_function_store = self
|
||||
scoring_functions = await p.list_scoring_functions()
|
||||
await add_objects(scoring_functions, pid, ScoringFn)
|
||||
|
||||
elif api == Api.eval:
|
||||
p.eval_task_store = self
|
||||
|
||||
|
|
|
@ -182,15 +182,6 @@ async def lifespan(app: FastAPI):
|
|||
await impl.shutdown()
|
||||
|
||||
|
||||
def create_dynamic_passthrough(
|
||||
downstream_url: str, downstream_headers: Optional[Dict[str, str]] = None
|
||||
):
|
||||
async def endpoint(request: Request):
|
||||
return await passthrough(request, downstream_url, downstream_headers)
|
||||
|
||||
return endpoint
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
|
@ -305,28 +296,19 @@ def main(
|
|||
endpoints = all_endpoints[api]
|
||||
impl = impls[api]
|
||||
|
||||
if is_passthrough(impl.__provider_spec__):
|
||||
for endpoint in endpoints:
|
||||
url = impl.__provider_config__.url.rstrip("/") + endpoint.route
|
||||
getattr(app, endpoint.method)(endpoint.route)(
|
||||
create_dynamic_passthrough(url)
|
||||
)
|
||||
else:
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(
|
||||
f"Could not find method {endpoint.name} on {impl}!!"
|
||||
)
|
||||
for endpoint in endpoints:
|
||||
if not hasattr(impl, endpoint.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
impl_method = getattr(impl, endpoint.name)
|
||||
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
)
|
||||
getattr(app, endpoint.method)(endpoint.route, response_model=None)(
|
||||
create_dynamic_typed_route(
|
||||
impl_method,
|
||||
endpoint.method,
|
||||
)
|
||||
)
|
||||
|
||||
cprint(f"Serving API {api_str}", "white", attrs=["bold"])
|
||||
for endpoint in endpoints:
|
||||
|
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import * # noqa: F403
|
|||
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
||||
|
@ -58,29 +58,23 @@ class LlamaStack(
|
|||
pass
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(
|
||||
run_config.metadata_store, run_config.image_name
|
||||
)
|
||||
RESOURCES = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
("shields", Api.shields, "register_shield", "list_shields"),
|
||||
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||
(
|
||||
"scoring_fns",
|
||||
Api.scoring_functions,
|
||||
"register_scoring_function",
|
||||
"list_scoring_functions",
|
||||
),
|
||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||
]
|
||||
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
|
||||
resources = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
("shields", Api.shields, "register_shield", "list_shields"),
|
||||
("memory_banks", Api.memory_banks, "register_memory_bank", "list_memory_banks"),
|
||||
("datasets", Api.datasets, "register_dataset", "list_datasets"),
|
||||
(
|
||||
"scoring_fns",
|
||||
Api.scoring_functions,
|
||||
"register_scoring_function",
|
||||
"list_scoring_functions",
|
||||
),
|
||||
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
|
||||
]
|
||||
for rsrc, api, register_method, list_method in resources:
|
||||
async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
if api not in impls:
|
||||
continue
|
||||
|
@ -96,4 +90,18 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
|||
)
|
||||
|
||||
print("")
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(
|
||||
run_config.metadata_store, run_config.image_name
|
||||
)
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(), dist_registry
|
||||
)
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue