Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Basic build and run succeeded

commit 2ed65a47a4 (parent 4d5ca49eed)
9 changed files with 50 additions and 42 deletions
@@ -6,7 +6,7 @@
 from typing import List, Optional, Protocol

-from llama_memory_banks.schema_utils import json_schema_type, webmethod
+from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field

 from llama_stack.apis.memory import MemoryBankType
@@ -6,7 +6,7 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type
@@ -139,10 +139,6 @@ Fully-qualified name of the module to import. The module is expected to have:
     provider_data_validator: Optional[str] = Field(
         default=None,
     )
-    supported_model_ids: List[str] = Field(
-        default_factory=list,
-        description="The list of model ids that this adapter supports",
-    )


 @json_schema_type
@@ -35,22 +35,22 @@ def stack_apis() -> List[Api]:


 class AutoRoutedApiInfo(BaseModel):
-    api_with_routing_table: Api
+    routing_table_api: Api
     router_api: Api


 def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
     return [
         AutoRoutedApiInfo(
-            api_with_routing_table=Api.models,
+            routing_table_api=Api.models,
             router_api=Api.inference,
         ),
         AutoRoutedApiInfo(
-            api_with_routing_table=Api.shields,
+            routing_table_api=Api.shields,
             router_api=Api.safety,
         ),
         AutoRoutedApiInfo(
-            api_with_routing_table=Api.memory_banks,
+            routing_table_api=Api.memory_banks,
             router_api=Api.memory,
         ),
     ]
@@ -97,7 +97,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
 def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
     routing_table_apis = set(
-        x.api_with_routing_table for x in builtin_automatically_routed_apis()
+        x.routing_table_api for x in builtin_automatically_routed_apis()
     )
     for api in stack_apis():
         if api in routing_table_apis:
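Note on the rename above: each AutoRoutedApiInfo pairs the API that owns a routing table (routing_table_api) with the API whose calls get routed through it (router_api). The sketch below is a minimal, self-contained illustration of that pairing and of how api_providers() derives the set of routing-table APIs; the trimmed-down Api enum is a stand-in, not the real llama_stack datatype.

from enum import Enum
from typing import List, Set

from pydantic import BaseModel


class Api(Enum):
    # Stand-in for llama_stack's Api enum, limited to the members used here.
    inference = "inference"
    safety = "safety"
    memory = "memory"
    models = "models"
    shields = "shields"
    memory_banks = "memory_banks"


class AutoRoutedApiInfo(BaseModel):
    routing_table_api: Api  # API that owns the routing table (e.g. models)
    router_api: Api         # API whose requests are auto-routed (e.g. inference)


def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
    return [
        AutoRoutedApiInfo(routing_table_api=Api.models, router_api=Api.inference),
        AutoRoutedApiInfo(routing_table_api=Api.shields, router_api=Api.safety),
        AutoRoutedApiInfo(routing_table_api=Api.memory_banks, router_api=Api.memory),
    ]


# api_providers() skips these when collecting ordinary provider registries.
routing_table_apis: Set[Api] = {
    x.routing_table_api for x in builtin_automatically_routed_apis()
}
print(routing_table_apis)  # contains Api.models, Api.shields, Api.memory_banks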
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Tuple
+from typing import Any, List, Tuple

 from llama_stack.distribution.datatypes import *  # noqa: F403
@@ -14,7 +14,7 @@ async def get_routing_table_impl(
     inner_impls: List[Tuple[str, Any]],
     routing_table_config: RoutingTableConfig,
     _deps,
-) -> Dict[str, List[ProviderRoutingEntry]]:
+) -> Any:
     from .routing_tables import (
         MemoryBanksRoutingTable,
         ModelsRoutingTable,
@@ -22,9 +22,9 @@ async def get_routing_table_impl(
     )

     api_to_tables = {
-        "memory": MemoryBanksRoutingTable,
-        "inference": ModelsRoutingTable,
-        "safety": ShieldsRoutingTable,
+        "memory_banks": MemoryBanksRoutingTable,
+        "models": ModelsRoutingTable,
+        "shields": ShieldsRoutingTable,
     }
     if api.value not in api_to_tables:
         raise ValueError(f"API {api.value} not found in router map")
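The key fix in this hunk: get_routing_table_impl now indexes the table map by the routing-table API's own value ("models", "shields", "memory_banks") instead of the router API names. A rough sketch of that lookup pattern follows; the empty table classes are placeholders for the real ModelsRoutingTable/ShieldsRoutingTable/MemoryBanksRoutingTable.

from enum import Enum


class Api(Enum):
    models = "models"
    shields = "shields"
    memory_banks = "memory_banks"


# Placeholders for the real routing-table implementations.
class ModelsRoutingTable: ...
class ShieldsRoutingTable: ...
class MemoryBanksRoutingTable: ...


def routing_table_class_for(api: Api) -> type:
    # Keys must be the routing-table APIs' values, which is what the diff fixes.
    api_to_tables = {
        "memory_banks": MemoryBanksRoutingTable,
        "models": ModelsRoutingTable,
        "shields": ShieldsRoutingTable,
    }
    if api.value not in api_to_tables:
        raise ValueError(f"API {api.value} not found in router map")
    return api_to_tables[api.value]


print(routing_table_class_for(Api.models).__name__)  # ModelsRoutingTable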
@@ -37,7 +37,6 @@ async def get_routing_table_impl(
 async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
     from .routers import InferenceRouter, MemoryRouter, SafetyRouter

     # TODO: make this completely dynamic
     api_to_routers = {
         "memory": MemoryRouter,
         "inference": InferenceRouter,
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

 from typing import Any, List, Optional, Tuple

 from llama_models.sku_list import resolve_model
 from llama_models.llama3.api.datatypes import *  # noqa: F403
@@ -47,7 +47,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 from llama_stack.distribution.datatypes import *  # noqa: F403

-from llama_stack.distribution.distribution import api_endpoints, api_providers
+from llama_stack.distribution.distribution import (
+    api_endpoints,
+    api_providers,
+    builtin_automatically_routed_apis,
+)
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.utils.dynamic import instantiate_provider
@@ -310,8 +314,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
         specs[api] = providers[config.provider_id]
         configs[api] = config

+    apis_to_serve = run_config.apis_to_serve or set(
+        list(specs.keys()) + list(run_config.routing_tables.keys())
+    )
+    print("apis_to_serve", apis_to_serve)
     for info in builtin_automatically_routed_apis():
-        source_api = info.api_with_routing_table
+        source_api = info.routing_table_api

         assert (
             source_api not in specs
@@ -320,6 +328,9 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
             info.router_api not in specs
         ), f"Auto-routed API {info.router_api} specified in wrong place?"

+        if info.router_api.value not in apis_to_serve:
+            continue
+
         if source_api.value not in run_config.routing_tables:
             raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
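Taken together, the two hunks above make resolve_impls_with_routing (a) derive the set of APIs to serve from the resolved provider specs plus the configured routing tables, and (b) skip an auto-routed API whose router is not being served, while still insisting on a routing-table entry when it is. The sketch below reproduces only that control flow; the dicts standing in for run_config and the provider specs are hypothetical.

from typing import Dict, List, Set

# Hypothetical stand-ins for run_config.routing_tables, the resolved specs,
# and run_config.apis_to_serve.
routing_tables: Dict[str, object] = {"models": object()}
spec_apis: List[str] = ["inference"]
configured_apis: List[str] = []

# (a) serve whatever has a provider spec or a routing table, unless overridden.
apis_to_serve: Set[str] = set(configured_apis) or set(spec_apis + list(routing_tables))

# (b) (routing_table_api, router_api) pairs from builtin_automatically_routed_apis().
for routing_table_api, router_api in [("models", "inference"), ("shields", "safety")]:
    if router_api not in apis_to_serve:
        continue  # router not requested: nothing to wire up
    if routing_table_api not in routing_tables:
        raise ValueError(f"Routing table for `{routing_table_api}` is not provided?")

print(sorted(apis_to_serve))  # ['inference', 'models']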
@@ -352,7 +363,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
         configs[info.router_api] = {}

     sorted_specs = topological_sort(specs.values())
+    print(f"Resolved {len(sorted_specs)} providers in topological order")
+    for spec in sorted_specs:
+        print(f" {spec.api}: {spec.provider_id}")
+    print("")
     impls = {}
     for spec in sorted_specs:
         api = spec.api
@@ -376,9 +390,17 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):

     all_endpoints = api_endpoints()

-    apis_to_serve = config.apis_to_serve or list(config.provider_map.keys())
+    if config.apis_to_serve:
+        apis_to_serve = set(config.apis_to_serve)
+        for inf in builtin_automatically_routed_apis():
+            if inf.router_api.value in apis_to_serve:
+                apis_to_serve.add(inf.routing_table_api)
+    else:
+        apis_to_serve = set(impls.keys())
+
     for api_str in apis_to_serve:
         api = Api(api_str)

         endpoints = all_endpoints[api]
         impl = impls[api]
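The same pairing drives which endpoints main() registers: explicitly asking for a router API such as inference now pulls in the API that owns its routing table (models) as well. A small stand-alone sketch of that expansion, using plain strings and a hard-coded pairs list in place of the real config objects:

from typing import List, Set, Tuple

# (routing_table_api, router_api) pairs, mirroring builtin_automatically_routed_apis().
AUTO_ROUTED: List[Tuple[str, str]] = [
    ("models", "inference"),
    ("shields", "safety"),
    ("memory_banks", "memory"),
]


def expand_apis_to_serve(requested: List[str]) -> Set[str]:
    """If a router API is requested, also serve the API owning its routing table."""
    apis = set(requested)
    for routing_table_api, router_api in AUTO_ROUTED:
        if router_api in apis:
            apis.add(routing_table_api)
    return apis


print(sorted(expand_apis_to_serve(["inference", "safety"])))
# ['inference', 'models', 'safety', 'shields']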
@@ -405,7 +427,11 @@ def main(yaml_config: str, port: int = 5000, disable_ipv6: bool = False):
                 create_dynamic_typed_route(
                     impl_method,
                     endpoint.method,
-                    provider_spec.provider_data_validator,
+                    (
+                        provider_spec.provider_data_validator
+                        if not isinstance(provider_spec, RoutingTableProviderSpec)
+                        else None
+                    ),
                 )
             )
@@ -38,6 +38,7 @@ async def instantiate_provider(
     elif isinstance(provider_spec, AutoRoutedProviderSpec):
         method = "get_auto_router_impl"

         config = None
         args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
     elif isinstance(provider_spec, RoutingTableProviderSpec):
         method = "get_routing_table_impl"
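For context on the new args line: instantiate_provider dispatches on the provider-spec type, and an auto-routed provider is handed the already-instantiated routing-table implementation it should consult. The sketch below shows only that dispatch shape; the dataclasses are trimmed-down, hypothetical stand-ins for AutoRoutedProviderSpec and RoutingTableProviderSpec, and the routing-table branch omits the inner impls and config the real call also passes.

from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class AutoRoutedProviderSpec:    # hypothetical stand-in
    api: str
    routing_table_api: str


@dataclass
class RoutingTableProviderSpec:  # hypothetical stand-in
    api: str


def pick_factory(provider_spec: Any, deps: Dict[str, Any]) -> Tuple[str, List[Any]]:
    """Return the factory method name and its positional args for a spec."""
    if isinstance(provider_spec, AutoRoutedProviderSpec):
        method = "get_auto_router_impl"
        # The auto-router receives the routing-table impl it routes against.
        args: List[Any] = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
    elif isinstance(provider_spec, RoutingTableProviderSpec):
        method = "get_routing_table_impl"
        args = [provider_spec.api, deps]  # simplified relative to the real call
    else:
        raise TypeError(f"Unhandled spec type: {type(provider_spec)}")
    return method, args


deps = {"models": "<models routing table impl>"}
spec = AutoRoutedProviderSpec(api="inference", routing_table_api="models")
print(pick_factory(spec, deps)[0])  # get_auto_router_impl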
@@ -6,17 +6,14 @@
 from typing import Optional

-from llama_models.datatypes import ModelFamily

 from llama_models.schema_utils import json_schema_type
+from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import all_registered_models, resolve_model

+from llama_stack.apis.inference import *  # noqa: F401, F403

 from pydantic import BaseModel, Field, field_validator

-from llama_stack.apis.inference import QuantizationConfig


 @json_schema_type
 class MetaReferenceImplConfig(BaseModel):
     model: str = Field(
         default="Meta-Llama3.1-8B-Instruct",
@@ -34,6 +31,7 @@ class MetaReferenceImplConfig(BaseModel):
             m.descriptor()
             for m in all_registered_models()
             if m.model_family == ModelFamily.llama3_1
+            or m.core_model_id == CoreModelId.llama_guard_3_8b
         ]
         if model not in permitted_models:
             model_list = "\n\t".join(permitted_models)
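This hunk widens the permitted-model check so the Llama Guard safety model is accepted alongside the Llama 3.1 family. A hedged sketch of how such a pydantic validator fits together is below; the mini ModelFamily/CoreModelId enums and the registered-model list are stand-ins for the llama_models types, and the error message is illustrative rather than the one in the repo.

from enum import Enum
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator


class ModelFamily(Enum):      # stand-in for llama_models' ModelFamily
    llama3_1 = "llama3_1"
    safety = "safety"


class CoreModelId(Enum):      # stand-in for llama_models' CoreModelId
    llama_guard_3_8b = "Llama-Guard-3-8B"


class RegisteredModel(BaseModel):  # stand-in for a sku-list entry
    model_config = ConfigDict(protected_namespaces=())  # allow the model_family field name

    name: str
    model_family: ModelFamily
    core_model_id: Optional[CoreModelId] = None

    def descriptor(self) -> str:
        return self.name


ALL_MODELS: List[RegisteredModel] = [
    RegisteredModel(name="Meta-Llama3.1-8B-Instruct", model_family=ModelFamily.llama3_1),
    RegisteredModel(
        name="Llama-Guard-3-8B",
        model_family=ModelFamily.safety,
        core_model_id=CoreModelId.llama_guard_3_8b,
    ),
]


class MetaReferenceImplConfig(BaseModel):
    model: str = Field(default="Meta-Llama3.1-8B-Instruct")

    @field_validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = [
            m.descriptor()
            for m in ALL_MODELS
            if m.model_family == ModelFamily.llama3_1
            or m.core_model_id == CoreModelId.llama_guard_3_8b
        ]
        if model not in permitted_models:
            model_list = "\n\t".join(permitted_models)
            raise ValueError(f"Unknown model `{model}`; permitted:\n\t{model_list}")
        return model


print(MetaReferenceImplConfig(model="Llama-Guard-3-8B").model)  # Llama-Guard-3-8B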
@@ -32,10 +32,6 @@ def available_providers() -> List[ProviderSpec]:
                 adapter_id="ollama",
                 pip_packages=["ollama"],
                 module="llama_stack.providers.adapters.inference.ollama",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                ],
             ),
         ),
         remote_provider_spec(
@@ -56,11 +52,6 @@ def available_providers() -> List[ProviderSpec]:
                 ],
                 module="llama_stack.providers.adapters.inference.fireworks",
                 config_class="llama_stack.providers.adapters.inference.fireworks.FireworksImplConfig",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                    "Meta-Llama3.1-405B-Instruct",
-                ],
             ),
         ),
         remote_provider_spec(
@@ -73,11 +64,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.adapters.inference.together",
                 config_class="llama_stack.providers.adapters.inference.together.TogetherImplConfig",
                 header_extractor_class="llama_stack.providers.adapters.inference.together.TogetherHeaderExtractor",
-                supported_model_ids=[
-                    "Meta-Llama3.1-8B-Instruct",
-                    "Meta-Llama3.1-70B-Instruct",
-                    "Meta-Llama3.1-405B-Instruct",
-                ],
             ),
         ),
     ]