Basic build and run succeeded

Ashwin Bharambe 2024-09-22 16:30:32 -07:00
parent 4d5ca49eed
commit 2ed65a47a4
9 changed files with 50 additions and 42 deletions

View file

@@ -6,7 +6,7 @@
 from typing import List, Optional, Protocol
-from llama_memory_banks.schema_utils import json_schema_type, webmethod
+from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 from llama_stack.apis.memory import MemoryBankType

View file

@@ -6,7 +6,7 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Union
 from llama_models.schema_utils import json_schema_type
@@ -139,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
     provider_data_validator: Optional[str] = Field(
         default=None,
     )
-    supported_model_ids: List[str] = Field(
-        default_factory=list,
-        description="The list of model ids that this adapter supports",
-    )
 @json_schema_type

View file

@@ -35,22 +35,22 @@ def stack_apis() -> List[Api]:
 class AutoRoutedApiInfo(BaseModel):
-    api_with_routing_table: Api
+    routing_table_api: Api
     router_api: Api
 def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
     return [
         AutoRoutedApiInfo(
-            api_with_routing_table=Api.models,
+            routing_table_api=Api.models,
             router_api=Api.inference,
         ),
         AutoRoutedApiInfo(
-            api_with_routing_table=Api.shields,
+            routing_table_api=Api.shields,
             router_api=Api.safety,
         ),
         AutoRoutedApiInfo(
-            api_with_routing_table=Api.memory_banks,
+            routing_table_api=Api.memory_banks,
             router_api=Api.memory,
         ),
     ]
@@ -97,7 +97,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
     routing_table_apis = set(
-        x.api_with_routing_table for x in builtin_automatically_routed_apis()
+        x.routing_table_api for x in builtin_automatically_routed_apis()
     )
     for api in stack_apis():
         if api in routing_table_apis:
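For illustration, here is a small self-contained sketch of what the rename expresses: each auto-routed API pairs a routing-table API with the router API it backs, and api_providers() now collects the routing-table side under that clearer name. The Api enum and dataclass below are stand-ins, not the real llama_stack types.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List


class Api(Enum):
    models = "models"
    inference = "inference"
    shields = "shields"
    safety = "safety"
    memory_banks = "memory_banks"
    memory = "memory"


@dataclass
class AutoRoutedApiInfo:
    routing_table_api: Api  # API that owns the routing table (e.g. models)
    router_api: Api  # API whose requests get auto-routed (e.g. inference)


def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
    return [
        AutoRoutedApiInfo(routing_table_api=Api.models, router_api=Api.inference),
        AutoRoutedApiInfo(routing_table_api=Api.shields, router_api=Api.safety),
        AutoRoutedApiInfo(routing_table_api=Api.memory_banks, router_api=Api.memory),
    ]


# Mirrors the api_providers() change: collect the routing-table side by its new name.
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
assert Api.models in routing_table_apis and Api.inference not in routing_table_apis
```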

View file

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple
 from llama_stack.distribution.datatypes import *  # noqa: F403
@@ -14,7 +14,7 @@ async def get_routing_table_impl(
     inner_impls: List[Tuple[str, Any]],
     routing_table_config: RoutingTableConfig,
     _deps,
-) -> Dict[str, List[ProviderRoutingEntry]]:
+) -> Any:
     from .routing_tables import (
         MemoryBanksRoutingTable,
         ModelsRoutingTable,
@@ -22,9 +22,9 @@ async def get_routing_table_impl(
     )
     api_to_tables = {
-        "memory": MemoryBanksRoutingTable,
-        "inference": ModelsRoutingTable,
-        "safety": ShieldsRoutingTable,
+        "memory_banks": MemoryBanksRoutingTable,
+        "models": ModelsRoutingTable,
+        "shields": ShieldsRoutingTable,
     }
     if api.value not in api_to_tables:
         raise ValueError(f"API {api.value} not found in router map")
@@ -37,7 +37,6 @@ async def get_routing_table_impl(
 async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
     from .routers import InferenceRouter, MemoryRouter, SafetyRouter
-    # TODO: make this completely dynamic
     api_to_routers = {
         "memory": MemoryRouter,
         "inference": InferenceRouter,

View file

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Any, List, Optional, Tuple
 from llama_models.sku_list import resolve_model
 from llama_models.llama3.api.datatypes import *  # noqa: F403

View file

@@ -47,7 +47,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 from llama_stack.distribution.datatypes import *  # noqa: F403
-from llama_stack.distribution.distribution import api_endpoints, api_providers
+from llama_stack.distribution.distribution import (
+    api_endpoints,
+    api_providers,
+    builtin_automatically_routed_apis,
+)
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.utils.dynamic import instantiate_provider
@@ -310,8 +314,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         specs[api] = providers[config.provider_id]
         configs[api] = config
+    apis_to_serve = run_config.apis_to_serve or set(
+        list(specs.keys()) + list(run_config.routing_tables.keys())
+    )
+    print("apis_to_serve", apis_to_serve)
     for info in builtin_automatically_routed_apis():
-        source_api = info.api_with_routing_table
+        source_api = info.routing_table_api
         assert (
             source_api not in specs
@@ -320,6 +328,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
             info.router_api not in specs
         ), f"Auto-routed API {info.router_api} specified in wrong place?"
+        if info.router_api.value not in apis_to_serve:
+            continue
         if source_api.value not in run_config.routing_tables:
             raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@@ -352,7 +363,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         configs[info.router_api] = {}
     sorted_specs = topological_sort(specs.values())
+    print(f"Resolved {len(sorted_specs)} providers in topological order")
+    for spec in sorted_specs:
+        print(f"  {spec.api}: {spec.provider_id}")
+    print("")
     impls = {}
     for spec in sorted_specs:
         api = spec.api
@@ -376,9 +390,17 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
     all_endpoints = api_endpoints()
-    apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
+    if config.apis_to_serve:
+        apis_to_serve = set(config.apis_to_serve)
+        for inf in builtin_automatically_routed_apis():
+            if inf.router_api.value in apis_to_serve:
+                apis_to_serve.add(inf.routing_table_api)
+    else:
+        apis_to_serve = set(impls.keys())
     for api_str in apis_to_serve:
         api = Api(api_str)
         endpoints = all_endpoints[api]
         impl = impls[api]
@@ -405,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
                 create_dynamic_typed_route(
                     impl_method,
                     endpoint.method,
-                    provider_spec.provider_data_validator,
+                    (
+                        provider_spec.provider_data_validator
+                        if not isinstance(provider_spec, RoutingTableProviderSpec)
+                        else None
+                    ),
                 )
             )
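The server change has the most moving parts: apis_to_serve is now derived from the run config (or from everything resolved), auto-routed APIs that are not being served are skipped, and requesting a router API implicitly pulls in its routing-table API. A minimal sketch of that expansion logic on plain strings; the helper name below is hypothetical, not the commit's code.

```python
from typing import List, Optional, Set

# Router API -> routing-table API pairs, as listed in builtin_automatically_routed_apis().
ROUTER_TO_TABLE = {
    "inference": "models",
    "safety": "shields",
    "memory": "memory_banks",
}


def expand_apis_to_serve(requested: Optional[List[str]], resolved: Set[str]) -> Set[str]:
    if requested:
        apis = set(requested)
        # Serving a router API also requires its routing table, so pull it in.
        for router_api, table_api in ROUTER_TO_TABLE.items():
            if router_api in apis:
                apis.add(table_api)
        return apis
    # No explicit list in the config: serve everything that was resolved.
    return set(resolved)


assert expand_apis_to_serve(["inference"], set()) == {"inference", "models"}
assert expand_apis_to_serve(None, {"memory", "memory_banks"}) == {"memory", "memory_banks"}
```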

View file

@@ -38,6 +38,7 @@ async def instantiate_provider(
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"
+        config = None
         args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"
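In other words, an auto-routed provider is instantiated with no config of its own; it only wraps the routing table that was already built as a dependency. A rough sketch of that dispatch with hypothetical stand-in spec classes (the real instantiate_provider then imports the provider module and calls the named method):

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class AutoRoutedProviderSpec:  # stand-in for illustration
    api: str
    routing_table_api: str


@dataclass
class RoutingTableProviderSpec:  # stand-in for illustration
    api: str


def plan_instantiation(provider_spec: Any, deps: Dict[str, Any]):
    if isinstance(provider_spec, AutoRoutedProviderSpec):
        # Auto-routed providers carry no config of their own; they wrap the
        # routing table that was already instantiated as a dependency.
        method = "get_auto_router_impl"
        config = None
        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
    elif isinstance(provider_spec, RoutingTableProviderSpec):
        method = "get_routing_table_impl"
        config, args = None, []  # simplified: the real code passes the table config
    else:
        raise NotImplementedError(type(provider_spec))
    return method, config, args


method, config, args = plan_instantiation(
    AutoRoutedProviderSpec(api="inference", routing_table_api="models"),
    {"models": object()},
)
assert method == "get_auto_router_impl" and config is None
```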

View file

@@ -6,17 +6,14 @@
 from typing import Optional
-from llama_models.datatypes import ModelFamily
-from llama_models.schema_utils import json_schema_type
+from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import all_registered_models, resolve_model
+from llama_stack.apis.inference import *  # noqa: F401, F403
 from pydantic import BaseModel, Field, field_validator
-from llama_stack.apis.inference import QuantizationConfig
-@json_schema_type
 class MetaReferenceImplConfig(BaseModel):
     model: str = Field(
         default="Meta-Llama3.1-8B-Instruct",
@@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
             m.descriptor()
             for m in all_registered_models()
             if m.model_family == ModelFamily.llama3_1
+            or m.core_model_id == CoreModelId.llama_guard_3_8b
         ]
         if model not in permitted_models:
             model_list = "\n\t".join(permitted_models)
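The effect of the validator change: Llama Guard 3 8B is now accepted alongside the Llama 3.1 family. A short sketch of the widened check, assuming the llama_models package is installed; ModelFamily and CoreModelId are the names the wildcard import above pulls in.

```python
from typing import List

# Assumes llama_models is available in the environment.
from llama_models.datatypes import CoreModelId, ModelFamily
from llama_models.sku_list import all_registered_models


def permitted_model_descriptors() -> List[str]:
    # Llama 3.1 family models plus the Llama Guard 3 8B safety model.
    return [
        m.descriptor()
        for m in all_registered_models()
        if m.model_family == ModelFamily.llama3_1
        or m.core_model_id == CoreModelId.llama_guard_3_8b
    ]


def validate_model(model: str) -> str:
    permitted = permitted_model_descriptors()
    if model not in permitted:
        model_list = "\n\t".join(permitted)
        raise ValueError(f"Unknown model: `{model}`. Choose from:\n\t{model_list}")
    return model
```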

View file

@@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
                 adapter_id="ollama",
                 pip_packages=["ollama"],
                 module="llama_stack.providers.adapters.inference.ollama",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                ],
             ),
         ),
         remote_provider_spec(
@@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
                 ],
                 module="llama_stack.providers.adapters.inference.fireworks",
                 config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                    "Meta-Llama3.1-405B-Instruct",
-                ],
             ),
         ),
         remote_provider_spec(
@@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.adapters.inference.together",
                 config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
                 header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                    "Meta-Llama3.1-405B-Instruct",
-                ],
             ),
         ),
     ]
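With supported_model_ids gone from the adapter datatype (see the second file above), a registry entry reduces to the adapter wiring alone. A hedged sketch of what the ollama entry looks like after this commit; the import path for AdapterSpec / remote_provider_spec is assumed, since it is not shown in the diff.

```python
# Import path assumed (not shown in this diff); the helpers are the ones the
# registry above calls.
from llama_stack.distribution.datatypes import AdapterSpec, Api, remote_provider_spec

ollama_spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(
        adapter_id="ollama",
        pip_packages=["ollama"],
        module="llama_stack.providers.adapters.inference.ollama",
        # No supported_model_ids any more: model support is resolved at runtime
        # through the models routing table instead of a static per-adapter list.
    ),
)
```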