mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
More work towards making remote stacks usable from tests
This commit is contained in:
parent
8645f8bc9e
commit
8b7be87bec
7 changed files with 91 additions and 99 deletions
|
@ -59,12 +59,16 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
|
||||
def additional_protocols_map() -> Dict[Api, Any]:
|
||||
return {
|
||||
Api.inference: (ModelsProtocolPrivate, Models),
|
||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
|
||||
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
|
||||
Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
|
||||
Api.inference: (ModelsProtocolPrivate, Models, Api.models),
|
||||
Api.memory: (MemoryBanksProtocolPrivate, MemoryBanks, Api.memory_banks),
|
||||
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
|
||||
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
|
||||
Api.scoring: (
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
ScoringFunctions,
|
||||
Api.scoring_functions,
|
||||
),
|
||||
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -33,83 +33,20 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
|
|||
|
||||
api = get_impl_api(p)
|
||||
|
||||
is_remote = obj.provider_id == "remote"
|
||||
if is_remote:
|
||||
# TODO: these are incomplete fixes since (a) they are kind of adhoc and likely to break
|
||||
# and (b) MemoryBankInput is missing BankParams
|
||||
if isinstance(obj, Model):
|
||||
obj = ModelInput(
|
||||
model_id=obj.identifier,
|
||||
metadata=obj.metadata,
|
||||
provider_model_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, Shield):
|
||||
obj = ShieldInput(
|
||||
shield_id=obj.identifier,
|
||||
params=obj.params,
|
||||
provider_shield_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, MemoryBank):
|
||||
# need to calculate params here
|
||||
obj = MemoryBankInput(
|
||||
memory_bank_id=obj.identifier,
|
||||
provider_memory_bank_id=obj.provider_resource_id,
|
||||
)
|
||||
elif isinstance(obj, ScoringFn):
|
||||
obj = ScoringFnInput(
|
||||
scoring_fn_id=obj.identifier,
|
||||
provider_scoring_fn_id=obj.provider_resource_id,
|
||||
description=obj.description,
|
||||
metadata=obj.metadata,
|
||||
return_type=obj.return_type,
|
||||
params=obj.params,
|
||||
)
|
||||
elif isinstance(obj, EvalTask):
|
||||
obj = EvalTaskInput(
|
||||
eval_task_id=obj.identifier,
|
||||
provider_eval_task_id=obj.provider_resource_id,
|
||||
dataset_id=obj.dataset_id,
|
||||
scoring_function_id=obj.scoring_functions,
|
||||
metadata=obj.metadata,
|
||||
)
|
||||
elif isinstance(obj, Dataset):
|
||||
obj = DatasetInput(
|
||||
dataset_id=obj.identifier,
|
||||
provider_dataset_id=obj.provider_resource_id,
|
||||
schema=obj.schema,
|
||||
url=obj.url,
|
||||
metadata=obj.metadata,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown object type {type(obj)}")
|
||||
assert obj.provider_id != "remote", "Remote provider should not be registered"
|
||||
|
||||
if api == Api.inference:
|
||||
return await p.register_model(obj)
|
||||
elif api == Api.safety:
|
||||
if is_remote:
|
||||
await p.register_shield(**obj.model_dump())
|
||||
else:
|
||||
await p.register_shield(obj)
|
||||
await p.register_shield(**obj.model_dump())
|
||||
elif api == Api.memory:
|
||||
if is_remote:
|
||||
await p.register_memory_bank(**obj.model_dump())
|
||||
else:
|
||||
await p.register_memory_bank(obj)
|
||||
await p.register_memory_bank(**obj.model_dump())
|
||||
elif api == Api.datasetio:
|
||||
if is_remote:
|
||||
await p.register_dataset(**obj.model_dump())
|
||||
else:
|
||||
await p.register_dataset(obj)
|
||||
await p.register_dataset(**obj.model_dump())
|
||||
elif api == Api.scoring:
|
||||
if is_remote:
|
||||
await p.register_scoring_function(**obj.model_dump())
|
||||
else:
|
||||
await p.register_scoring_function(obj)
|
||||
await p.register_scoring_function(**obj.model_dump())
|
||||
elif api == Api.eval:
|
||||
if is_remote:
|
||||
await p.register_eval_task(**obj.model_dump())
|
||||
else:
|
||||
await p.register_eval_task(obj)
|
||||
await p.register_eval_task(**obj.model_dump())
|
||||
else:
|
||||
raise ValueError(f"Unknown API {api} for registering object with provider")
|
||||
|
||||
|
@ -137,15 +74,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if cls is None:
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
if provider_id == "remote":
|
||||
# if this is just a passthrough, we got the *WithProvider object
|
||||
# so we should just override the provider in-place
|
||||
obj.provider_id = provider_id
|
||||
else:
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
# Create a copy of the model data and explicitly set provider_id
|
||||
model_data = obj.model_dump()
|
||||
model_data["provider_id"] = provider_id
|
||||
obj = cls(**model_data)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
# Register all objects from providers
|
||||
|
|
|
@ -28,11 +28,16 @@ from llama_stack.apis.shields import * # noqa: F403
|
|||
from llama_stack.apis.inspect import * # noqa: F403
|
||||
from llama_stack.apis.eval_tasks import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.client import get_client_impl
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.resolver import (
|
||||
additional_protocols_map,
|
||||
api_protocol_map,
|
||||
resolve_impls,
|
||||
)
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.datatypes import Api, RemoteProviderConfig
|
||||
|
||||
|
||||
class LlamaStack(
|
||||
|
@ -65,7 +70,9 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
|||
run_config.metadata_store, run_config.image_name
|
||||
)
|
||||
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
impls = await maybe_get_remote_stack_impls(run_config)
|
||||
if impls is None:
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
|
||||
resources = [
|
||||
("models", Api.models, "register_model", "list_models"),
|
||||
|
@ -97,3 +104,54 @@ async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
|
|||
|
||||
print("")
|
||||
return impls
|
||||
|
||||
|
||||
# NOTE: this code path is really for the tests so you can send HTTP requests
|
||||
# to the remote stack without needing to use llama-stack-client
|
||||
async def maybe_get_remote_stack_impls(
|
||||
run_config: StackRunConfig,
|
||||
) -> Optional[Dict[Api, Any]]:
|
||||
remote_config = remote_provider_config(run_config)
|
||||
if not remote_config:
|
||||
return None
|
||||
|
||||
protocols = api_protocol_map()
|
||||
additional_protocols = additional_protocols_map()
|
||||
|
||||
impls = {}
|
||||
for api_str in run_config.apis:
|
||||
api = Api(api_str)
|
||||
impls[api] = await get_client_impl(
|
||||
protocols[api],
|
||||
None,
|
||||
remote_config,
|
||||
{},
|
||||
)
|
||||
if api in additional_protocols:
|
||||
_, additional_protocol, additional_api = additional_protocols[api]
|
||||
impls[additional_api] = await get_client_impl(
|
||||
additional_protocol,
|
||||
None,
|
||||
remote_config,
|
||||
{},
|
||||
)
|
||||
|
||||
return impls
|
||||
|
||||
|
||||
def remote_provider_config(
|
||||
run_config: StackRunConfig,
|
||||
) -> Optional[RemoteProviderConfig]:
|
||||
remote_config = None
|
||||
has_non_remote = False
|
||||
for api_providers in run_config.providers.values():
|
||||
for provider in api_providers:
|
||||
if provider.provider_type == "remote":
|
||||
remote_config = RemoteProviderConfig(**provider.config)
|
||||
else:
|
||||
has_non_remote = True
|
||||
|
||||
if remote_config:
|
||||
assert not has_non_remote, "Remote stack cannot have non-remote providers"
|
||||
|
||||
return remote_config
|
||||
|
|
|
@ -53,6 +53,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter_type="chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||
module="llama_stack.providers.remote.memory.chroma",
|
||||
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
|
|
|
@ -85,4 +85,5 @@ async def agents_stack(request, inference_model, safety_shield):
|
|||
],
|
||||
shields=[safety_shield],
|
||||
)
|
||||
|
||||
return impls[Api.agents], impls[Api.memory]
|
||||
|
|
|
@ -186,12 +186,7 @@ async def inference_stack(request, inference_model):
|
|||
[Api.inference],
|
||||
{"inference": inference_fixture.providers},
|
||||
inference_fixture.provider_data,
|
||||
models=[
|
||||
ModelInput(
|
||||
model_id=inference_model,
|
||||
provider_id=inference_fixture.providers[0].provider_id,
|
||||
)
|
||||
],
|
||||
models=[ModelInput(model_id=inference_model)],
|
||||
)
|
||||
|
||||
return (impls[Api.inference], impls[Api.models])
|
||||
|
|
|
@ -17,6 +17,7 @@ from llama_stack.distribution.build import print_pip_install_help
|
|||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.stack import construct_stack
|
||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||
|
||||
|
@ -25,12 +26,12 @@ async def resolve_impls_for_test_v2(
|
|||
apis: List[Api],
|
||||
providers: Dict[str, List[Provider]],
|
||||
provider_data: Optional[Dict[str, Any]] = None,
|
||||
models: Optional[List[Model]] = None,
|
||||
shields: Optional[List[Shield]] = None,
|
||||
memory_banks: Optional[List[MemoryBank]] = None,
|
||||
datasets: Optional[List[Dataset]] = None,
|
||||
scoring_fns: Optional[List[ScoringFn]] = None,
|
||||
eval_tasks: Optional[List[EvalTask]] = None,
|
||||
models: Optional[List[ModelInput]] = None,
|
||||
shields: Optional[List[ShieldInput]] = None,
|
||||
memory_banks: Optional[List[MemoryBankInput]] = None,
|
||||
datasets: Optional[List[DatasetInput]] = None,
|
||||
scoring_fns: Optional[List[ScoringFnInput]] = None,
|
||||
eval_tasks: Optional[List[EvalTaskInput]] = None,
|
||||
):
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
run_config = dict(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue