From 2ed65a47a4e38eedbb74addee47e508002c49602 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 22 Sep 2024 16:30:32 -0700 Subject: [PATCH] Basic build and run succeeded --- llama_stack/apis/memory_banks/memory_banks.py | 2 +- llama_stack/distribution/datatypes.py | 6 +--- llama_stack/distribution/distribution.py | 10 +++--- llama_stack/distribution/routers/__init__.py | 11 +++--- .../distribution/routers/routing_tables.py | 2 ++ llama_stack/distribution/server/server.py | 36 ++++++++++++++++--- llama_stack/distribution/utils/dynamic.py | 1 + .../impls/meta_reference/inference/config.py | 10 +++--- llama_stack/providers/registry/inference.py | 14 -------- 9 files changed, 50 insertions(+), 42 deletions(-) diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 901899512..23bfb69e1 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -6,7 +6,7 @@ from typing import List, Optional, Protocol -from llama_memory_banks.schema_utils import json_schema_type, webmethod +from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from llama_stack.apis.memory import MemoryBankType diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 1bdcb5473..52522886e 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type @@ -139,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have: provider_data_validator: Optional[str] = Field( default=None, ) - supported_model_ids: List[str] = Field( - default_factory=list, - description="The list of model ids that this adapter supports", - ) @json_schema_type diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index bfea3da5f..6b72afed5 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -35,22 +35,22 @@ def stack_apis() -> List[Api]: class AutoRoutedApiInfo(BaseModel): - api_with_routing_table: Api + routing_table_api: Api router_api: Api def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: return [ AutoRoutedApiInfo( - api_with_routing_table=Api.models, + routing_table_api=Api.models, router_api=Api.inference, ), AutoRoutedApiInfo( - api_with_routing_table=Api.shields, + routing_table_api=Api.shields, router_api=Api.safety, ), AutoRoutedApiInfo( - api_with_routing_table=Api.memory_banks, + routing_table_api=Api.memory_banks, router_api=Api.memory, ), ] @@ -97,7 +97,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]: ret = {} routing_table_apis = set( - x.api_with_routing_table for x in builtin_automatically_routed_apis() + x.routing_table_api for x in builtin_automatically_routed_apis() ) for api in stack_apis(): if api in routing_table_apis: diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index d8e076072..e8b8938b0 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple from llama_stack.distribution.datatypes import * # noqa: F403 @@ -14,7 +14,7 @@ async def get_routing_table_impl( inner_impls: List[Tuple[str, Any]], routing_table_config: RoutingTableConfig, _deps, -) -> Dict[str, List[ProviderRoutingEntry]]: +) -> Any: from .routing_tables import ( MemoryBanksRoutingTable, ModelsRoutingTable, @@ -22,9 +22,9 @@ async def get_routing_table_impl( ) api_to_tables = { - "memory": MemoryBanksRoutingTable, - "inference": ModelsRoutingTable, - "safety": ShieldsRoutingTable, + "memory_banks": MemoryBanksRoutingTable, + "models": ModelsRoutingTable, + "shields": ShieldsRoutingTable, } if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") @@ -37,7 +37,6 @@ async def get_routing_table_impl( async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: from .routers import InferenceRouter, MemoryRouter, SafetyRouter - # TODO: make this completely dynamic api_to_routers = { "memory": MemoryRouter, "inference": InferenceRouter, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 54e853c73..a3f40b2b7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any, List, Optional, Tuple + from llama_models.sku_list import resolve_model from llama_models.llama3.api.datatypes import * # noqa: F403 diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 24fe42ab5..18433596f 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -47,7 +47,11 @@ from llama_stack.providers.utils.telemetry.tracing import ( ) from llama_stack.distribution.datatypes import * # noqa: F403 -from llama_stack.distribution.distribution import api_endpoints, api_providers +from llama_stack.distribution.distribution import ( + api_endpoints, + api_providers, + builtin_automatically_routed_apis, +) from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.utils.dynamic import instantiate_provider @@ -310,8 +314,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An specs[api] = providers[config.provider_id] configs[api] = config + apis_to_serve = run_config.apis_to_serve or set( + list(specs.keys()) + list(run_config.routing_tables.keys()) + ) + print("apis_to_serve", apis_to_serve) for info in builtin_automatically_routed_apis(): - source_api = info.api_with_routing_table + source_api = info.routing_table_api assert ( source_api not in specs @@ -320,6 +328,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An info.router_api not in specs ), f"Auto-routed API {info.router_api} specified in wrong place?" + if info.router_api.value not in apis_to_serve: + continue + if source_api.value not in run_config.routing_tables: raise ValueError(f"Routing table for `{source_api.value}` is not provided?") @@ -352,7 +363,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An configs[info.router_api] = {} sorted_specs = topological_sort(specs.values()) - + print(f"Resolved {len(sorted_specs)} providers in topological order") + for spec in sorted_specs: + print(f" {spec.api}: {spec.provider_id}") + print("") impls = {} for spec in sorted_specs: api = spec.api @@ -376,9 +390,17 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): all_endpoints = api_endpoints() - apis_to_serve = config.apis_to_serve or list(config.provider_map.keys()) + if config.apis_to_serve: + apis_to_serve = set(config.apis_to_serve) + for inf in builtin_automatically_routed_apis(): + if inf.router_api.value in apis_to_serve: + apis_to_serve.add(inf.routing_table_api) + else: + apis_to_serve = set(impls.keys()) + for api_str in apis_to_serve: api = Api(api_str) + endpoints = all_endpoints[api] impl = impls[api] @@ -405,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False): create_dynamic_typed_route( impl_method, endpoint.method, - provider_spec.provider_data_validator, + ( + provider_spec.provider_data_validator + if not isinstance(provider_spec, RoutingTableProviderSpec) + else None + ), ) ) diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 431feddc1..6d9c57dfd 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -38,6 +38,7 @@ async def instantiate_provider( elif isinstance(provider_spec, AutoRoutedProviderSpec): method = "get_auto_router_impl" + config = None args = [provider_spec.api, deps[provider_spec.routing_table_api], deps] elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py index 8e3d3ed3c..d9b397571 100644 --- a/llama_stack/providers/impls/meta_reference/inference/config.py +++ b/llama_stack/providers/impls/meta_reference/inference/config.py @@ -6,17 +6,14 @@ from typing import Optional -from llama_models.datatypes import ModelFamily - -from llama_models.schema_utils import json_schema_type +from llama_models.datatypes import * # noqa: F403 from llama_models.sku_list import all_registered_models, resolve_model +from llama_stack.apis.inference import * # noqa: F401, F403 + from pydantic import BaseModel, Field, field_validator -from llama_stack.apis.inference import QuantizationConfig - -@json_schema_type class MetaReferenceImplConfig(BaseModel): model: str = Field( default="Meta-Llama3.1-8B-Instruct", @@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel): m.descriptor() for m in all_registered_models() if m.model_family == ModelFamily.llama3_1 + or m.core_model_id == CoreModelId.llama_guard_3_8b ] if model not in permitted_models: model_list = "\n\t".join(permitted_models) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index bf739eefa..10b3d6ccc 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]: adapter_id="ollama", pip_packages=["ollama"], module="llama_stack.providers.adapters.inference.ollama", - supported_model_ids=[ - "Meta-Llama3.1-8B-Instruct", - "Meta-Llama3.1-70B-Instruct", - ], ), ), remote_provider_spec( @@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]: ], module="llama_stack.providers.adapters.inference.fireworks", config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig", - supported_model_ids=[ - "Meta-Llama3.1-8B-Instruct", - "Meta-Llama3.1-70B-Instruct", - "Meta-Llama3.1-405B-Instruct", - ], ), ), remote_provider_spec( @@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]: module="llama_stack.providers.adapters.inference.together", config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig", header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor", - supported_model_ids=[ - "Meta-Llama3.1-8B-Instruct", - "Meta-Llama3.1-70B-Instruct", - "Meta-Llama3.1-405B-Instruct", - ], ), ), ]