Basic build and run succeeded

This commit is contained in:
Ashwin Bharambe 2024-09-22 16:30:32 -07:00
parent 4d5ca49eed
commit 2ed65a47a4
9 changed files with 50 additions and 42 deletions

View file

@ -6,7 +6,7 @@
from typing import List, Optional, Protocol
from llama_memory_banks.schema_utils import json_schema_type, webmethod
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType

View file

@ -6,7 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type
@ -139,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
provider_data_validator: Optional[str] = Field(
default=None,
)
supported_model_ids: List[str] = Field(
default_factory=list,
description="The list of model ids that this adapter supports",
)
@json_schema_type

View file

@ -35,22 +35,22 @@ def stack_apis() -> List[Api]:
# Record describing one automatically routed API pairing: an API that holds a
# routing table and the router API associated with it (field meanings are
# inferred from the field names only -- TODO confirm against the repository).
# NOTE(review): this text is a flattened diff hunk whose +/- markers and
# indentation were lost. `api_with_routing_table` and `routing_table_api`
# look like the before/after spelling of one renamed field rather than two
# separate fields -- confirm before treating this as the real class body.
class AutoRoutedApiInfo(BaseModel):
api_with_routing_table: Api
routing_table_api: Api
router_api: Api
# Return the built-in list of automatically routed API pairings. Three entries
# are listed: models -> inference, shields -> safety, memory_banks -> memory.
# NOTE(review): this text is a flattened diff hunk whose +/- markers and
# indentation were lost. Every entry carries both `api_with_routing_table=`
# and `routing_table_api=` with the same value; presumably these are the
# removed and added lines of a keyword rename, not two keywords passed
# together -- confirm against the repository.
def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
return [
AutoRoutedApiInfo(
api_with_routing_table=Api.models,
routing_table_api=Api.models,
router_api=Api.inference,
),
AutoRoutedApiInfo(
api_with_routing_table=Api.shields,
routing_table_api=Api.shields,
router_api=Api.safety,
),
AutoRoutedApiInfo(
api_with_routing_table=Api.memory_banks,
routing_table_api=Api.memory_banks,
router_api=Api.memory,
),
]
@ -97,7 +97,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
ret = {}
routing_table_apis = set(
x.api_with_routing_table for x in builtin_automatically_routed_apis()
x.routing_table_api for x in builtin_automatically_routed_apis()
)
for api in stack_apis():
if api in routing_table_apis:

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple
from llama_stack.distribution.datatypes import * # noqa: F403
@ -14,7 +14,7 @@ async def get_routing_table_impl(
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
_deps,
) -> Dict[str, List[ProviderRoutingEntry]]:
) -> Any:
from .routing_tables import (
MemoryBanksRoutingTable,
ModelsRoutingTable,
@ -22,9 +22,9 @@ async def get_routing_table_impl(
)
api_to_tables = {
"memory": MemoryBanksRoutingTable,
"inference": ModelsRoutingTable,
"safety": ShieldsRoutingTable,
"memory_banks": MemoryBanksRoutingTable,
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
@ -37,7 +37,6 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import InferenceRouter, MemoryRouter, SafetyRouter
# TODO: make this completely dynamic
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, List, Optional, Tuple
from llama_models.sku_list import resolve_model
from llama_models.llama3.api.datatypes import * # noqa: F403

View file

@ -47,7 +47,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
)
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import api_endpoints, api_providers
from llama_stack.distribution.distribution import (
api_endpoints,
api_providers,
builtin_automatically_routed_apis,
)
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.utils.dynamic import instantiate_provider
@ -310,8 +314,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
specs[api] = providers[config.provider_id]
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_tables.keys())
)
print("apis_to_serve", apis_to_serve)
for info in builtin_automatically_routed_apis():
source_api = info.api_with_routing_table
source_api = info.routing_table_api
assert (
source_api not in specs
@ -320,6 +328,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
info.router_api not in specs
), f"Auto-routed API {info.router_api} specified in wrong place?"
if info.router_api.value not in apis_to_serve:
continue
if source_api.value not in run_config.routing_tables:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@ -352,7 +363,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
configs[info.router_api] = {}
sorted_specs = topological_sort(specs.values())
print(f"Resolved {len(sorted_specs)} providers in topological order")
for spec in sorted_specs:
print(f" {spec.api}: {spec.provider_id}")
print("")
impls = {}
for spec in sorted_specs:
api = spec.api
@ -376,9 +390,17 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
all_endpoints = api_endpoints()
apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
if config.apis_to_serve:
apis_to_serve = set(config.apis_to_serve)
for inf in builtin_automatically_routed_apis():
if inf.router_api.value in apis_to_serve:
apis_to_serve.add(inf.routing_table_api)
else:
apis_to_serve = set(impls.keys())
for api_str in apis_to_serve:
api = Api(api_str)
endpoints = all_endpoints[api]
impl = impls[api]
@ -405,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
create_dynamic_typed_route(
impl_method,
endpoint.method,
provider_spec.provider_data_validator,
(
provider_spec.provider_data_validator
if not isinstance(provider_spec, RoutingTableProviderSpec)
else None
),
)
)

View file

@ -38,6 +38,7 @@ async def instantiate_provider(
elif isinstance(provider_spec, AutoRoutedProviderSpec):
method = "get_auto_router_impl"
config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"

View file

@ -6,17 +6,14 @@
from typing import Optional
from llama_models.datatypes import ModelFamily
from llama_models.schema_utils import json_schema_type
from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import all_registered_models, resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.inference import QuantizationConfig
@json_schema_type
class MetaReferenceImplConfig(BaseModel):
model: str = Field(
default="Meta-Llama3.1-8B-Instruct",
@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
m.descriptor()
for m in all_registered_models()
if m.model_family == ModelFamily.llama3_1
or m.core_model_id == CoreModelId.llama_guard_3_8b
]
if model not in permitted_models:
model_list = "\n\t".join(permitted_models)

View file

@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
adapter_id="ollama",
pip_packages=["ollama"],
module="llama_stack.providers.adapters.inference.ollama",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
],
),
),
remote_provider_spec(
@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.adapters.inference.fireworks",
config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
"Meta-Llama3.1-405B-Instruct",
],
),
),
remote_provider_spec(
@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
module="llama_stack.providers.adapters.inference.together",
config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
supported_model_ids=[
"Meta-Llama3.1-8B-Instruct",
"Meta-Llama3.1-70B-Instruct",
"Meta-Llama3.1-405B-Instruct",
],
),
),
]