More work towards making remote stacks usable from tests

This commit is contained in:
Ashwin Bharambe 2024-11-12 17:09:31 -08:00
parent 8645f8bc9e
commit 8b7be87bec
7 changed files with 91 additions and 99 deletions

View file

@ -59,12 +59,16 @@ def api_protocol_map() -> Dict[Api, Any]:
def additional_protocols_map() -> Dict[Api, Any]:
return {
Api.inference: (ModelsProtocolPrivate, Models),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
Api.safety: (ShieldsProtocolPrivate, Shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
Api.scoring: (
ScoringFunctionsProtocolPrivate,
ScoringFunctions,
Api.scoring_functions,
),
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
}

View file

@ -33,83 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
api = get_impl_api(p)
is_remote = obj.provider_id == "remote"
if is_remote:
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
# and (b) MemoryBankInput is missing BankParams
if isinstance(obj, Model):
obj = ModelInput(
model_id=obj.identifier,
metadata=obj.metadata,
provider_model_id=obj.provider_resource_id,
)
elif isinstance(obj, Shield):
obj = ShieldInput(
shield_id=obj.identifier,
params=obj.params,
provider_shield_id=obj.provider_resource_id,
)
elif isinstance(obj, MemoryBank):
# need to calculate params here
obj = MemoryBankInput(
memory_bank_id=obj.identifier,
provider_memory_bank_id=obj.provider_resource_id,
)
elif isinstance(obj, ScoringFn):
obj = ScoringFnInput(
scoring_fn_id=obj.identifier,
provider_scoring_fn_id=obj.provider_resource_id,
description=obj.description,
metadata=obj.metadata,
return_type=obj.return_type,
params=obj.params,
)
elif isinstance(obj, EvalTask):
obj = EvalTaskInput(
eval_task_id=obj.identifier,
provider_eval_task_id=obj.provider_resource_id,
dataset_id=obj.dataset_id,
scoring_function_id=obj.scoring_functions,
metadata=obj.metadata,
)
elif isinstance(obj, Dataset):
obj = DatasetInput(
dataset_id=obj.identifier,
provider_dataset_id=obj.provider_resource_id,
schema=obj.schema,
url=obj.url,
metadata=obj.metadata,
)
else:
raise ValueError(f"Unknown object type {type(obj)}")
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
if is_remote:
await p.register_shield(**obj.model_dump())
else:
await p.register_shield(obj)
await p.register_shield(**obj.model_dump())
elif api == Api.memory:
if is_remote:
await p.register_memory_bank(**obj.model_dump())
else:
await p.register_memory_bank(obj)
await p.register_memory_bank(**obj.model_dump())
elif api == Api.datasetio:
if is_remote:
await p.register_dataset(**obj.model_dump())
else:
await p.register_dataset(obj)
await p.register_dataset(**obj.model_dump())
elif api == Api.scoring:
if is_remote:
await p.register_scoring_function(**obj.model_dump())
else:
await p.register_scoring_function(obj)
await p.register_scoring_function(**obj.model_dump())
elif api == Api.eval:
if is_remote:
await p.register_eval_task(**obj.model_dump())
else:
await p.register_eval_task(obj)
await p.register_eval_task(**obj.model_dump())
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -137,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
if cls is None:
obj.provider_id = provider_id
else:
if provider_id == "remote":
# if this is just a passthrough, we got the *WithProvider object
# so we should just override the provider in-place
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers

View file

@ -28,11 +28,16 @@ from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.client import get_client_impl
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.resolver import (
additional_protocols_map,
api_protocol_map,
resolve_impls,
)
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.datatypes import Api
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
class LlamaStack(
@ -65,7 +70,9 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
impls = await maybe_get_remote_stack_impls(run_config)
if impls is None:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
resources = [
("models", Api.models, "register_model", "list_models"),
@ -97,3 +104,54 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
print("")
return impls
# NOTE: this code path is really for the tests so you can send HTTP requests
# to the remote stack without needing to use llama-stack-client
async def maybe_get_remote_stack_impls(
run_config: StackRunConfig,
) -> Optional[Dict[Api, Any]]:
remote_config = remote_provider_config(run_config)
if not remote_config:
return None
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
impls = {}
for api_str in run_config.apis:
api = Api(api_str)
impls[api] = await get_client_impl(
protocols[api],
None,
remote_config,
{},
)
if api in additional_protocols:
_, additional_protocol, additional_api = additional_protocols[api]
impls[additional_api] = await get_client_impl(
additional_protocol,
None,
remote_config,
{},
)
return impls
def remote_provider_config(
run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
remote_config = None
has_non_remote = False
for api_providers in run_config.providers.values():
for provider in api_providers:
if provider.provider_type == "remote":
remote_config = RemoteProviderConfig(**provider.config)
else:
has_non_remote = True
if remote_config:
assert not has_non_remote, "Remote stack cannot have non-remote providers"
return remote_config

View file

@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]:
adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
),
),
remote_provider_spec(

View file

@ -85,4 +85,5 @@ async def agents_stack(request, inference_model, safety_shield):
],
shields=[safety_shield],
)
return impls[Api.agents], impls[Api.memory]

View file

@ -186,12 +186,7 @@ async def inference_stack(request, inference_model):
[Api.inference],
{"inference": inference_fixture.providers},
inference_fixture.provider_data,
models=[
ModelInput(
model_id=inference_model,
provider_id=inference_fixture.providers[0].provider_id,
)
],
models=[ModelInput(model_id=inference_model)],
)
return (impls[Api.inference], impls[Api.models])

View file

@ -17,6 +17,7 @@ from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.stack import construct_stack
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
@ -25,12 +26,12 @@ async def resolve_impls_for_test_v2(
apis: List[Api],
providers: Dict[str, List[Provider]],
provider_data: Optional[Dict[str, Any]] = None,
models: Optional[List[Model]] = None,
shields: Optional[List[Shield]] = None,
memory_banks: Optional[List[MemoryBank]] = None,
datasets: Optional[List[Dataset]] = None,
scoring_fns: Optional[List[ScoringFn]] = None,
eval_tasks: Optional[List[EvalTask]] = None,
models: Optional[List[ModelInput]] = None,
shields: Optional[List[ShieldInput]] = None,
memory_banks: Optional[List[MemoryBankInput]] = None,
datasets: Optional[List[DatasetInput]] = None,
scoring_fns: Optional[List[ScoringFnInput]] = None,
eval_tasks: Optional[List[EvalTaskInput]] = None,
):
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config = dict(