apis_to_serve -> apis

This commit is contained in:
Ashwin Bharambe 2024-10-05 23:16:11 -07:00 committed by Ashwin Bharambe
parent 59302a86df
commit 60dead6196
7 changed files with 38 additions and 48 deletions

View file

@@ -154,7 +154,7 @@ class StackConfigure(Subcommand):
config = StackRunConfig(
built_at=datetime.now(),
image_name=image_name,
apis_to_serve=[],
apis=[],
providers={},
models=[],
shields=[],

View file

@@ -46,6 +46,7 @@ class StackRun(Subcommand):
import pkg_resources
import yaml
from termcolor import cprint
from llama_stack.distribution.build import ImageType
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
@@ -75,6 +76,7 @@ class StackRun(Subcommand):
)
return
cprint(f"Using config `{config_file}`", "green")
with open(config_file, "r") as f:
config = StackRunConfig(**yaml.safe_load(f))

View file

@@ -64,8 +64,8 @@ def configure_api_providers(
) -> StackRunConfig:
is_nux = len(config.providers) == 0
apis = set((config.apis_to_serve or list(build_spec.providers.keys())))
config.apis_to_serve = [a for a in apis if a != "telemetry"]
apis = set((config.apis or list(build_spec.providers.keys())))
config.apis = [a for a in apis if a != "telemetry"]
if is_nux:
print(
@@ -79,7 +79,7 @@ def configure_api_providers(
provider_registry = get_provider_registry()
builtin_apis = [a.routing_table_api for a in builtin_automatically_routed_apis()]
for api_str in config.apis_to_serve:
for api_str in config.apis:
api = Api(api_str)
if api in builtin_apis:
continue
@@ -342,6 +342,9 @@ def upgrade_from_routing_table_to_registry(
del config_dict["routing_table"]
del config_dict["api_providers"]
config_dict["apis"] = config_dict["apis_to_serve"]
del config_dict["apis_to_serve"]
return config_dict

View file

@@ -39,15 +39,6 @@ RoutedProtocol = Union[
]
class GenericProviderConfig(BaseModel):
provider_type: str
config: Dict[str, Any]
class RoutableProviderConfig(GenericProviderConfig):
routing_key: RoutingKey
# Example: /inference, /safety
class AutoRoutedProviderSpec(ProviderSpec):
provider_type: str = "router"
@@ -92,7 +83,6 @@ in the runtime configuration to help route to the correct provider.""",
)
# TODO: rename as ProviderInstanceConfig
class Provider(BaseModel):
provider_id: str
provider_type: str
@@ -118,40 +108,36 @@ this could be just a hash
default=None,
description="Reference to the conda environment if this package refers to a conda environment",
)
apis_to_serve: List[str] = Field(
apis: List[str] = Field(
description="""
The list of APIs to serve. If not specified, all APIs specified in the provider_map will be served""",
)
providers: Dict[str, List[Provider]]
providers: Dict[str, List[Provider]] = Field(
description="""
One or more providers to use for each API. The same provider_type (e.g., meta-reference)
can be instantiated multiple times (with different configs) if necessary.
""",
)
models: List[ModelDef]
shields: List[ShieldDef]
memory_banks: List[MemoryBankDef]
# api_providers: Dict[
# str, Union[GenericProviderConfig, PlaceholderProviderConfig]
# ] = Field(
# description="""
# Provider configurations for each of the APIs provided by this package.
# """,
# )
# routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
# default_factory=dict,
# description="""
# E.g. The following is a ProviderRoutingEntry for models:
# - routing_key: Llama3.1-8B-Instruct
# provider_type: meta-reference
# config:
# model: Llama3.1-8B-Instruct
# quantization: null
# torch_seed: null
# max_seq_len: 4096
# max_batch_size: 1
# """,
# )
models: List[ModelDef] = Field(
description="""
List of model definitions to serve. This list may get extended by
/models/register API calls at runtime.
""",
)
shields: List[ShieldDef] = Field(
description="""
List of shield definitions to serve. This list may get extended by
/shields/register API calls at runtime.
""",
)
memory_banks: List[MemoryBankDef] = Field(
description="""
List of memory bank definitions to serve. This list may get extended by
/memory_banks/register API calls at runtime.
""",
)
class BuildConfig(BaseModel):

View file

@@ -59,7 +59,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
key = api_str if api not in router_apis else f"inner-{api_str}"
providers_with_specs[key] = specs
apis_to_serve = run_config.apis_to_serve or set(
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + list(routing_table_apis)
)

View file

@@ -291,8 +291,8 @@ def main(
all_endpoints = get_all_api_endpoints()
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
if config.apis:
apis_to_serve = set(config.apis)
else:
apis_to_serve = set(impls.keys())

View file

@@ -20,8 +20,7 @@ from llama_stack.providers.utils.inference.augment_messages import (
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
# TODO: Eventually this will move to the llama cli model list command
# mapping of Model SKUs to ollama models
OLLAMA_SUPPORTED_SKUS = {
"Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
"Llama3.1-70B-Instruct": "llama3.1:70b-instruct-fp16",