apis_to_serve -> apis

This commit is contained in:
Ashwin Bharambe 2024-10-05 23:16:11 -07:00 committed by Ashwin Bharambe
parent 59302a86df
commit 60dead6196
7 changed files with 38 additions and 48 deletions

View file

@@ -154,7 +154,7 @@ class StackConfigure(Subcommand):
config = StackRunConfig( config = StackRunConfig(
built_at=datetime.now(), built_at=datetime.now(),
image_name=image_name, image_name=image_name,
apis_to_serve=[], apis=[],
providers={}, providers={},
models=[], models=[],
shields=[], shields=[],

View file

@@ -46,6 +46,7 @@ class StackRun(Subcommand):
import pkg_resources import pkg_resources
import yaml import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType from llama_stack.distribution.build import ImageType
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
@@ -75,6 +76,7 @@ class StackRun(Subcommand):
) )
return return
cprint(f"Using config `{config_file}`", "green")
with open(config_file, "r") as f: with open(config_file, "r") as f:
config = StackRunConfig(**yaml.safe_load(f)) config = StackRunConfig(**yaml.safe_load(f))

View file

@@ -64,8 +64,8 @@ def configure_api_providers(
) -> StackRunConfig: ) -> StackRunConfig:
is_nux = len(config.providers) == 0 is_nux = len(config.providers) == 0
apis = set((config.apis_to_serve or list(build_spec.providers.keys()))) apis = set((config.apis or list(build_spec.providers.keys())))
config.apis_to_serve = [a for a in apis if a != "telemetry"] config.apis = [a for a in apis if a != "telemetry"]
if is_nux: if is_nux:
print( print(
@@ -79,7 +79,7 @@ def configure_api_providers(
provider_registry = get_provider_registry() provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()] builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
for api_str in config.apis_to_serve: for api_str in config.apis:
api = Api(api_str) api = Api(api_str)
if api in builtin_apis: if api in builtin_apis:
continue continue
@@ -342,6 +342,9 @@ def upgrade_from_routing_table_to_registry(
del config_dict["routing_table"] del config_dict["routing_table"]
del config_dict["api_providers"] del config_dict["api_providers"]
config_dict["apis"] = config_dict["apis_to_serve"]
del config_dict["apis_to_serve"]
return config_dict return config_dict

View file

@@ -39,15 +39,6 @@ RoutedProtocol = Union[
] ]
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
# Example: /inference, /safety # Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec): class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router" provider_type: str = "router"
@@ -92,7 +83,6 @@ in the runtime configuration to help route to the correct provider.""",
) )
# TODO: rename as ProviderInstanceConfig
class Provider(BaseModel): class Provider(BaseModel):
provider_id: str provider_id: str
provider_type: str provider_type: str
@@ -118,40 +108,36 @@ this could be just a hash
default=None, default=None,
description="Reference to the conda environment if this package refers to a conda environment", description="Reference to the conda environment if this package refers to a conda environment",
) )
apis_to_serve: List[str] = Field( apis: List[str] = Field(
description=""" description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""", The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
) )
providers: Dict[str, List[Provider]] providers: Dict[str, List[Provider]] = Field(
description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
models: List[ModelDef] models: List[ModelDef] = Field(
shields: List[ShieldDef] description="""
memory_banks: List[MemoryBankDef] List of model definitions to serve. This list may get extended by
/models/register API calls at runtime.
""",
# api_providers: Dict[ )
# str, Union[GenericProviderConfig, PlaceholderProviderConfig] shields: List[ShieldDef] = Field(
# ] = Field( description="""
# description=""" List of shield definitions to serve. This list may get extended by
# Provider configurations for each of the APIs provided by this package. /shields/register API calls at runtime.
# """, """,
# ) )
# routing_table: Dict[str, List[RoutableProviderConfig]] = Field( memory_banks: List[MemoryBankDef] = Field(
# default_factory=dict, description="""
# description=""" List of memory bank definitions to serve. This list may get extended by
/memory_banks/register API calls at runtime.
# E.g. The following is a ProviderRoutingEntry for models: """,
# - routing_key: Llama3.1-8B-Instruct )
# provider_type: meta-reference
# config:
# model: Llama3.1-8B-Instruct
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
# """,
# )
class BuildConfig(BaseModel): class BuildConfig(BaseModel):

View file

@@ -59,7 +59,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
key = api_str if api not in router_apis else f"inner-{api_str}" key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs providers_with_specs[key] = specs
apis_to_serve = run_config.apis_to_serve or set( apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + list(routing_table_apis) list(providers_with_specs.keys()) + list(routing_table_apis)
) )

View file

@@ -291,8 +291,8 @@ def main(
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
if config.apis_to_serve: if config.apis:
apis_to_serve = set(config.apis_to_serve) apis_to_serve = set(config.apis)
else: else:
apis_to_serve = set(impls.keys()) apis_to_serve = set(impls.keys())

View file

@@ -20,8 +20,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
) )
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = { OLLAMA_SUPPORTED_SKUS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16", "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16", "Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",