diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py
index 76ade470e..9aa7e2f6e 100644
--- a/llama_stack/cli/stack/configure.py
+++ b/llama_stack/cli/stack/configure.py
@@ -154,7 +154,7 @@ class StackConfigure(Subcommand):
         config = StackRunConfig(
             built_at=datetime.now(),
             image_name=image_name,
-            apis_to_serve=[],
+            apis=[],
             providers={},
             models=[],
             shields=[],
diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py
index 1c528baed..033b2a81f 100644
--- a/llama_stack/cli/stack/run.py
+++ b/llama_stack/cli/stack/run.py
@@ -46,6 +46,7 @@ class StackRun(Subcommand):
         import pkg_resources
         import yaml
+        from termcolor import cprint

         from llama_stack.distribution.build import ImageType
         from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
@@ -75,6 +76,7 @@ class StackRun(Subcommand):
             )
             return

+        cprint(f"Using config `{config_file}`", "green")
         with open(config_file, "r") as f:
             config = StackRunConfig(**yaml.safe_load(f))

diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py
index b40cff242..f343c13bb 100644
--- a/llama_stack/distribution/configure.py
+++ b/llama_stack/distribution/configure.py
@@ -64,8 +64,8 @@ def configure_api_providers(
 ) -> StackRunConfig:
     is_nux = len(config.providers) == 0

-    apis = set((config.apis_to_serve or list(build_spec.providers.keys())))
-    config.apis_to_serve = [a for a in apis if a != "telemetry"]
+    apis = set((config.apis or list(build_spec.providers.keys())))
+    config.apis = [a for a in apis if a != "telemetry"]

     if is_nux:
         print(
@@ -79,7 +79,7 @@ def configure_api_providers(
     provider_registry = get_provider_registry()
     builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]

-    for api_str in config.apis_to_serve:
+    for api_str in config.apis:
         api = Api(api_str)
         if api in builtin_apis:
             continue
@@ -342,6 +342,9 @@ def upgrade_from_routing_table_to_registry(
     del config_dict["routing_table"]
     del config_dict["api_providers"]

+    config_dict["apis"] = config_dict["apis_to_serve"]
+    del config_dict["apis_to_serve"]
+
     return config_dict

diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 05b2ad0d6..c987d4c87 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -39,15 +39,6 @@ RoutedProtocol = Union[
 ]


-class GenericProviderConfig(BaseModel):
-    provider_type: str
-    config: Dict[str, Any]
-
-
-class RoutableProviderConfig(GenericProviderConfig):
-    routing_key: RoutingKey
-
-
 # Example: /inference, /safety
 class AutoRoutedProviderSpec(ProviderSpec):
     provider_type: str = "router"
@@ -92,7 +83,6 @@ in the runtime configuration to help route to the correct provider.""",
     )


-# TODO: rename as ProviderInstanceConfig
 class Provider(BaseModel):
     provider_id: str
     provider_type: str
@@ -118,40 +108,36 @@ this could be just a hash
         default=None,
         description="Reference to the conda environment if this package refers to a conda environment",
     )
-    apis_to_serve: List[str] = Field(
+    apis: List[str] = Field(
         description="""
 The list of APIs to serve. If not specified, all APIs specified in the provider_map
 will be served""",
     )

-    providers: Dict[str, List[Provider]]
+    providers: Dict[str, List[Provider]] = Field(
+        description="""
+One or more providers to use for each API. The same provider_type (e.g., meta-reference)
+can be instantiated multiple times (with different configs) if necessary.
+""",
+    )

-    models: List[ModelDef]
-    shields: List[ShieldDef]
-    memory_banks: List[MemoryBankDef]
-
-
-# api_providers: Dict[
-#     str, Union[GenericProviderConfig, PlaceholderProviderConfig]
-# ] = Field(
-#     description="""
-# Provider configurations for each of the APIs provided by this package.
-# """,
-# )
-# routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
-#     default_factory=dict,
-#     description="""
-
-# E.g. The following is a ProviderRoutingEntry for models:
-# - routing_key: Llama3.1-8B-Instruct
-#   provider_type: meta-reference
-#   config:
-#     model: Llama3.1-8B-Instruct
-#     quantization: null
-#     torch_seed: null
-#     max_seq_len: 4096
-#     max_batch_size: 1
-# """,
-# )
+    models: List[ModelDef] = Field(
+        description="""
+List of model definitions to serve. This list may get extended by
+/models/register API calls at runtime.
+""",
+    )
+    shields: List[ShieldDef] = Field(
+        description="""
+List of shield definitions to serve. This list may get extended by
+/shields/register API calls at runtime.
+""",
+    )
+    memory_banks: List[MemoryBankDef] = Field(
+        description="""
+List of memory bank definitions to serve. This list may get extended by
+/memory_banks/register API calls at runtime.
+""",
+    )


 class BuildConfig(BaseModel):
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 2c383587c..d0c3adb84 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -59,7 +59,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
         key = api_str if api not in router_apis else f"inner-{api_str}"
         providers_with_specs[key] = specs

-    apis_to_serve = run_config.apis_to_serve or set(
+    apis_to_serve = run_config.apis or set(
         list(providers_with_specs.keys()) + list(routing_table_apis)
     )
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index f664bb674..ed3b4b9f2 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -291,8 +291,8 @@ def main(
     all_endpoints = get_all_api_endpoints()

-    if config.apis_to_serve:
-        apis_to_serve = set(config.apis_to_serve)
+    if config.apis:
+        apis_to_serve = set(config.apis)
     else:
         apis_to_serve = set(impls.keys())
diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py
index bc1b3d103..aa9a25658 100644
--- a/llama_stack/providers/adapters/inference/ollama/ollama.py
+++ b/llama_stack/providers/adapters/inference/ollama/ollama.py
@@ -20,8 +20,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
 )
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

-# TODO: Eventually this will move to the llama cli model list command
-# mapping of Model SKUs to ollama models
+
 OLLAMA_SUPPORTED_SKUS = {
     "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
     "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",