diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 6917617bc..8bbe7f6de 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -29,17 +29,12 @@ class ModelServingSpec(BaseModel):
 
 @json_schema_type
 class ModelsListResponse(BaseModel):
-    models_list: List[ModelSpec]
+    models_list: List[ModelServingSpec]
 
 
 @json_schema_type
 class ModelsGetResponse(BaseModel):
-    core_model_spec: Optional[ModelSpec] = None
-
-
-@json_schema_type
-class ModelsRegisterResponse(BaseModel):
-    core_model_spec: Optional[ModelSpec] = None
+    core_model_spec: Optional[ModelServingSpec] = None
 
 
 class Models(Protocol):
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 24b8443bf..7f15e0495 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -20,6 +20,7 @@ class Api(Enum):
     agents = "agents"
     memory = "memory"
     telemetry = "telemetry"
+    models = "models"
 
 
 @json_schema_type
@@ -81,6 +82,29 @@ class RouterProviderSpec(ProviderSpec):
         raise AssertionError("Should not be called on RouterProviderSpec")
 
 
+@json_schema_type
+class BuiltinProviderSpec(ProviderSpec):
+    provider_id: str = "builtin"
+    config_class: str = ""
+    docker_image: Optional[str] = None
+    api_dependencies: List[Api] = []
+    provider_data_validator: Optional[str] = Field(
+        default=None,
+    )
+    module: str = Field(
+        ...,
+        description="""
+        Fully-qualified name of the module to import. The module is expected to have:
+
+         - `get_router_impl(config, provider_specs, deps)`: returns the router implementation
+        """,
+    )
+    pip_packages: List[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+
+
 @json_schema_type
 class AdapterSpec(BaseModel):
     adapter_id: str = Field(
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 0825121dc..3dd406ccc 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -11,6 +11,7 @@ from typing import Dict, List
 from llama_stack.apis.agents import Agents
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.memory import Memory
+from llama_stack.apis.models import Models
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.telemetry import Telemetry
 
@@ -38,6 +39,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
         Api.agents: Agents,
         Api.memory: Memory,
         Api.telemetry: Telemetry,
+        Api.models: Models,
     }
 
     for api, protocol in protocols.items():
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 468501980..76d467881 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -51,6 +51,7 @@ from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import api_endpoints, api_providers
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.utils.dynamic import (
+    instantiate_builtin_provider,
     instantiate_provider,
     instantiate_router,
 )
@@ -306,15 +307,6 @@ async def resolve_impls_with_routing(
         api = Api(api_str)
         providers = all_providers[api]
 
-        # check for regular providers without routing
-        if api_str in stack_run_config.provider_map:
-            provider_map_entry = stack_run_config.provider_map[api_str]
-            if provider_map_entry.provider_id not in providers:
-                raise ValueError(
-                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
-                )
-            specs[api] = providers[provider_map_entry.provider_id]
-
         # check for routing table, we need to pass routing table to the router implementation
         if api_str in stack_run_config.provider_routing_table:
             specs[api] = RouterProviderSpec(
@@ -323,6 +315,19 @@ async def resolve_impls_with_routing(
                 api_dependencies=[],
                 routing_table=stack_run_config.provider_routing_table[api_str],
             )
+        else:
+            if api_str in stack_run_config.provider_map:
+                provider_map_entry = stack_run_config.provider_map[api_str]
+                provider_id = provider_map_entry.provider_id
+            else:
+                # not defined in config, will be a builtin provider, assign builtin provider id
+                provider_id = "builtin"
+
+            if provider_id not in providers:
+                raise ValueError(
+                    f"Unknown provider `{provider_id}` is not available for API `{api}`"
+                )
+            specs[api] = providers[provider_id]
 
     sorted_specs = topological_sort(specs.values())
 
@@ -338,7 +343,7 @@ async def resolve_impls_with_routing(
                 spec, api.value, stack_run_config.provider_routing_table
             )
         else:
-            raise ValueError(f"Cannot find provider_config for Api {api.value}")
+            impl = await instantiate_builtin_provider(spec, stack_run_config)
         impls[api] = impl
 
     return impls, specs
diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py
index f807b096d..85254b246 100644
--- a/llama_stack/distribution/utils/dynamic.py
+++ b/llama_stack/distribution/utils/dynamic.py
@@ -29,6 +29,18 @@ async def instantiate_router(
     return impl
 
 
+async def instantiate_builtin_provider(
+    provider_spec: BuiltinProviderSpec,
+    run_config: StackRunConfig,
+):
+    print("!!! instantiate_builtin_provider")
+    module = importlib.import_module(provider_spec.module)
+    fn = getattr(module, "get_builtin_impl")
+    impl = await fn(run_config)
+    impl.__provider_spec__ = provider_spec
+    return impl
+
+
 # returns a class implementing the protocol corresponding to the Api
 async def instantiate_provider(
     provider_spec: ProviderSpec,
diff --git a/llama_stack/examples/router-table-run.yaml b/llama_stack/examples/router-table-run.yaml
index 0f83cef06..6bc0c6d14 100644
--- a/llama_stack/examples/router-table-run.yaml
+++ b/llama_stack/examples/router-table-run.yaml
@@ -8,6 +8,7 @@ apis_to_serve:
 - telemetry
 - agents
 - safety
+- models
 provider_map:
   telemetry:
     provider_id: meta-reference
@@ -22,27 +23,39 @@ provider_map:
       disable_output_check: false
       prompt_guard_shield:
         model: Prompt-Guard-86M
+  # inference:
+  #   provider_id: meta-reference
+  #   config:
+  #     model: Meta-Llama3.1-8B-Instruct
+  #     quantization: null
+  #     torch_seed: null
+  #     max_seq_len: 4096
+  #     max_batch_size: 1
+  inference:
+    provider_id: remote::ollama
+    config:
+
   agents:
     provider_id: meta-reference
     config: {}
 provider_routing_table:
-  inference:
-    - routing_key: Meta-Llama3.1-8B-Instruct
-      provider_id: meta-reference
-      config:
-        model: Meta-Llama3.1-8B-Instruct
-        quantization: null
-        torch_seed: null
-        max_seq_len: 4096
-        max_batch_size: 1
-    # - routing_key: Meta-Llama3.1-8B
-    #   provider_id: meta-reference
-    #   config:
-    #     model: Meta-Llama3.1-8B
-    #     quantization: null
-    #     torch_seed: null
-    #     max_seq_len: 4096
-    #     max_batch_size: 1
+  # inference:
+  #   - routing_key: Meta-Llama3.1-8B-Instruct
+  #     provider_id: meta-reference
+  #     config:
+  #       model: Meta-Llama3.1-8B-Instruct
+  #       quantization: null
+  #       torch_seed: null
+  #       max_seq_len: 4096
+  #       max_batch_size: 1
+  #   - routing_key: Meta-Llama3.1-8B-Instruct
+  #     provider_id: meta-reference
+  #     config:
+  #       model: Meta-Llama3.1-8B
+  #       quantization: null
+  #       torch_seed: null
+  #       max_seq_len: 4096
+  #       max_batch_size: 1
   memory:
     # - routing_key: keyvalue
     #   provider_id: remote::pgvector
diff --git a/llama_stack/providers/impls/builtin/models/__init__.py b/llama_stack/providers/impls/builtin/models/__init__.py
new file mode 100644
index 000000000..788ecfbab
--- /dev/null
+++ b/llama_stack/providers/impls/builtin/models/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec, StackRunConfig
+
+
+async def get_builtin_impl(config: StackRunConfig):
+    from .models import BuiltinModelsImpl
+
+    assert isinstance(config, StackRunConfig), f"Unexpected config type: {type(config)}"
+
+    impl = BuiltinModelsImpl(config)
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/impls/builtin/models/config.py b/llama_stack/providers/impls/builtin/models/config.py
new file mode 100644
index 000000000..0a21e3b20
--- /dev/null
+++ b/llama_stack/providers/impls/builtin/models/config.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_models.schema_utils import json_schema_type
+from pydantic import BaseModel
+
+
+@json_schema_type
+class BuiltinImplConfig(BaseModel): ...
diff --git a/llama_stack/providers/impls/builtin/models/models.py b/llama_stack/providers/impls/builtin/models/models.py
new file mode 100644
index 000000000..842302000
--- /dev/null
+++ b/llama_stack/providers/impls/builtin/models/models.py
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+import asyncio
+
+from typing import AsyncIterator, Union
+
+from llama_models.llama3.api.datatypes import StopReason
+from llama_models.sku_list import resolve_model
+
+from llama_stack.apis.models import *  # noqa: F403
+from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_models.datatypes import CoreModelId, Model
+from llama_models.sku_list import resolve_model
+
+from llama_stack.distribution.datatypes import (
+    Api,
+    GenericProviderConfig,
+    StackRunConfig,
+)
+from termcolor import cprint
+
+
+class BuiltinModelsImpl(Models):
+    def __init__(
+        self,
+        config: StackRunConfig,
+    ) -> None:
+        print("BuiltinModelsImpl init")
+
+        self.run_config = config
+        self.models = {}
+
+        print("BuiltinModelsImpl run_config", config)
+
+        # check against inference & safety api
+        apis_with_models = [Api.inference, Api.safety]
+
+        for api in apis_with_models:
+            # check against provider_map (simple case single model)
+            if api.value in config.provider_map:
+                provider_spec = config.provider_map[api.value]
+                core_model_id = provider_spec.config
+                print("provider_spec", provider_spec)
+                model_spec = ModelServingSpec(
+                    provider_config=provider_spec,
+                )
+                # get supported model ids from the provider
+                supported_model_ids = self.get_supported_model_ids(provider_spec)
+                for model_id in supported_model_ids:
+                    self.models[model_id] = ModelServingSpec(
+                        llama_model=resolve_model(model_id),
+                        provider_config=provider_spec,
+                        api=api.value,
+                    )
+
+        # check against provider_routing_table (router with multiple models)
+        # with routing table, we use the routing_key as the supported models
+
+    def get_supported_model_ids(self, provider_spec: GenericProviderConfig) -> list[str]:
+        # TODO: for remote providers, provide registry to list supported models
+
+        return ["Meta-Llama3.1-8B-Instruct"]
+
+    async def initialize(self) -> None:
+        pass
+
+    async def list_models(self) -> ModelsListResponse:
+        pass
+        # return ModelsListResponse(models_list=list(self.models.values()))
+
+    async def get_model(self, core_model_id: str) -> ModelsGetResponse:
+        pass
+        # if core_model_id in self.models:
+        #     return ModelsGetResponse(core_model_spec=self.models[core_model_id])
+        # raise RuntimeError(f"Cannot find {core_model_id} in model registry")
diff --git a/llama_stack/providers/registry/models.py b/llama_stack/providers/registry/models.py
new file mode 100644
index 000000000..47ec948c4
--- /dev/null
+++ b/llama_stack/providers/registry/models.py
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import List
+
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+def available_providers() -> List[ProviderSpec]:
+    return [
+        BuiltinProviderSpec(
+            api=Api.models,
+            provider_id="builtin",
+            pip_packages=[],
+            module="llama_stack.providers.impls.builtin.models",
+            config_class="llama_stack.providers.impls.builtin.models.config.BuiltinImplConfig",
+            api_dependencies=[],
+        )
+    ]
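
The resolution order introduced in server.py above (routing table first, then provider_map, then the implicit "builtin" provider) can be seen in isolation with the standalone Python sketch below. The names and types are simplified stand-ins, not the actual llama_stack datatypes.

# Standalone sketch of the provider-resolution order in resolve_impls_with_routing:
# a routing-table entry wins, then an explicit provider_map entry, then "builtin".
from typing import Any, Dict, List


def resolve_provider_id(
    api: str,
    provider_map: Dict[str, str],             # api -> configured provider_id
    provider_routing_table: Dict[str, List],  # api -> routing entries
    registered: Dict[str, Any],               # provider_id -> registered provider spec
) -> str:
    if api in provider_routing_table:
        return "router"  # handled by a RouterProviderSpec rather than a registered spec
    provider_id = provider_map.get(api, "builtin")
    if provider_id not in registered:
        raise ValueError(f"Unknown provider `{provider_id}` is not available for API `{api}`")
    return provider_id


# The models API is absent from both tables, so it falls back to "builtin":
assert resolve_provider_id("models", {}, {}, {"builtin": object()}) == "builtin"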
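A minimal sketch of exercising the new builtin models provider outside the server loop, assuming this patch is applied. The YAML path and the StackRunConfig(**...) parsing are illustrative assumptions, and since list_models() is still stubbed out in this patch, the sketch inspects impl.models (the registry built in __init__) directly.

# Sketch only: smoke-test the builtin models provider from a run config.
import asyncio

import yaml

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.providers.impls.builtin.models import get_builtin_impl


async def main() -> None:
    # Assumed path; any run config with an `inference` entry in provider_map works
    with open("llama_stack/examples/router-table-run.yaml") as f:
        run_config = StackRunConfig(**yaml.safe_load(f))

    # get_builtin_impl constructs BuiltinModelsImpl and awaits initialize()
    impl = await get_builtin_impl(run_config)

    # __init__ keys ModelServingSpec entries by the model ids each configured provider serves
    for model_id, spec in impl.models.items():
        print(model_id, "->", spec.provider_config.provider_id)


asyncio.run(main())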