From 546f05bd3f58e4dbdf254799f3f0cb0383c183a5 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 2 Oct 2024 12:25:54 -0700
Subject: [PATCH 01/11] No automatic pager

---
 llama_stack/cli/model/prompt_format.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py
index e6fd8aac7..67f456175 100644
--- a/llama_stack/cli/model/prompt_format.py
+++ b/llama_stack/cli/model/prompt_format.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 import argparse
-import subprocess
 import textwrap
 from io import StringIO

@@ -110,7 +109,4 @@ def render_markdown_to_pager(markdown_content: str):
     console = Console(file=output, force_terminal=True, width=100)  # Set a fixed width
     console.print(md)
     rendered_content = output.getvalue()
-
-    # Pipe to pager
-    pager = subprocess.Popen(["less", "-R"], stdin=subprocess.PIPE)
-    pager.communicate(input=rendered_content.encode())
+    print(rendered_content)
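The function patched above now renders markdown in-memory and prints it instead of spawning `less -R`. A minimal standalone sketch of that pattern, assuming only the `rich` package (the function name here is illustrative, not the repo's exact code):

```python
# Sketch: render markdown to ANSI text in a buffer with rich, then print it.
# Mirrors the post-patch behavior of render_markdown_to_pager (no pager).
from io import StringIO

from rich.console import Console
from rich.markdown import Markdown


def render_markdown(markdown_content: str) -> str:
    output = StringIO()
    # force_terminal=True keeps ANSI colors even though we write to a buffer
    console = Console(file=output, force_terminal=True, width=100)
    console.print(Markdown(markdown_content))
    return output.getvalue()


print(render_markdown("# Hello\n\nSome *rendered* markdown."))
```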
From df68db644bbb22860727bfbf1635536910b7e533 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 2 Oct 2024 13:20:17 -0700
Subject: [PATCH 02/11] Refactoring distribution/distribution.py

This file was becoming too large, and it was unclear what it housed.
Split it into pieces.
---
 llama_stack/cli/stack/build.py               |  4 +-
 llama_stack/cli/stack/list_providers.py      |  4 +-
 llama_stack/distribution/build.py            | 14 ++++-
 llama_stack/distribution/configure.py        |  4 +-
 llama_stack/distribution/distribution.py     | 61 +------------------
 llama_stack/distribution/resolver.py         |  4 +-
 llama_stack/distribution/server/endpoints.py | 64 ++++++++++++++++++++
 llama_stack/distribution/server/server.py    |  5 +-
 llama_stack/providers/datatypes.py           |  7 ---
 9 files changed, 89 insertions(+), 78 deletions(-)
 create mode 100644 llama_stack/distribution/server/endpoints.py

diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py
index ef1f1a807..b7c25fa1b 100644
--- a/llama_stack/cli/stack/build.py
+++ b/llama_stack/cli/stack/build.py
@@ -175,7 +175,7 @@ class StackBuild(Subcommand):
         import yaml
         from llama_stack.distribution.distribution import (
             Api,
-            api_providers,
+            get_provider_registry,
             builtin_automatically_routed_apis,
         )
         from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -245,7 +245,7 @@ class StackBuild(Subcommand):
         )

         providers = dict()
-        all_providers = api_providers()
+        all_providers = get_provider_registry()
         routing_table_apis = set(
             x.routing_table_api for x in builtin_automatically_routed_apis()
         )

diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py
index 18c4de201..25875ecbf 100644
--- a/llama_stack/cli/stack/list_providers.py
+++ b/llama_stack/cli/stack/list_providers.py
@@ -34,9 +34,9 @@ class StackListProviders(Subcommand):

     def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
         from llama_stack.cli.table import print_table
-        from llama_stack.distribution.distribution import Api, api_providers
+        from llama_stack.distribution.distribution import Api, get_provider_registry

-        all_providers = api_providers()
+        all_providers = get_provider_registry()
         providers_for_api = all_providers[Api(args.api)]

         # eventually, this should query a registry at llama.meta.com/llamastack/distributions

diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py
index dabcad2a6..fe778bdb8 100644
--- a/llama_stack/distribution/build.py
+++ b/llama_stack/distribution/build.py
@@ -17,7 +17,17 @@
 from llama_stack.distribution.utils.exec import run_with_pty
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from pathlib import Path
-from llama_stack.distribution.distribution import api_providers, SERVER_DEPENDENCIES
+from llama_stack.distribution.distribution import get_provider_registry
+
+
+# These are the dependencies needed by the distribution server.
+# `llama-stack` is automatically installed by the installation script.
+SERVER_DEPENDENCIES = [
+    "fastapi",
+    "fire",
+    "httpx",
+    "uvicorn",
+]


 class ImageType(Enum):
@@ -42,7 +52,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
     )

     # extend package dependencies based on providers spec
-    all_providers = api_providers()
+    all_providers = get_provider_registry()
     for (
         api_str,
         provider_or_providers,

diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py
index d3b807d4a..e9b682dc0 100644
--- a/llama_stack/distribution/configure.py
+++ b/llama_stack/distribution/configure.py
@@ -15,8 +15,8 @@ from termcolor import cprint

 from llama_stack.apis.memory.memory import MemoryBankType
 from llama_stack.distribution.distribution import (
-    api_providers,
     builtin_automatically_routed_apis,
+    get_provider_registry,
     stack_apis,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -62,7 +62,7 @@ def configure_api_providers(
     config.apis_to_serve = list(set([a for a in apis if a != "telemetry"]))

     apis = [v.value for v in stack_apis()]
-    all_providers = api_providers()
+    all_providers = get_provider_registry()

     # configure simple case for with non-routing providers to api_providers
     for api_str in spec.providers.keys():

diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 035febb80..0c47fd750 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -5,30 +5,11 @@
 # the root directory of this source tree.

 import importlib
-import inspect
 from typing import Dict, List

 from pydantic import BaseModel

-from llama_stack.apis.agents import Agents
-from llama_stack.apis.inference import Inference
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBanks
-from llama_stack.apis.models import Models
-from llama_stack.apis.safety import Safety
-from llama_stack.apis.shields import Shields
-from llama_stack.apis.telemetry import Telemetry
-
-from .datatypes import Api, ApiEndpoint, ProviderSpec, remote_provider_spec
-
-# These are the dependencies needed by the distribution server.
-# `llama-stack` is automatically installed by the installation script.
-SERVER_DEPENDENCIES = [
-    "fastapi",
-    "fire",
-    "httpx",
-    "uvicorn",
-]
+from llama_stack.providers.datatypes import Api, ProviderSpec, remote_provider_spec


 def stack_apis() -> List[Api]:
@@ -57,45 +38,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
     ]


-def api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
-    apis = {}
-
-    protocols = {
-        Api.inference: Inference,
-        Api.safety: Safety,
-        Api.agents: Agents,
-        Api.memory: Memory,
-        Api.telemetry: Telemetry,
-        Api.models: Models,
-        Api.shields: Shields,
-        Api.memory_banks: MemoryBanks,
-    }
-
-    for api, protocol in protocols.items():
-        endpoints = []
-        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
-
-        for name, method in protocol_methods:
-            if not hasattr(method, "__webmethod__"):
-                continue
-
-            webmethod = method.__webmethod__
-            route = webmethod.route
-
-            if webmethod.method == "GET":
-                method = "get"
-            elif webmethod.method == "DELETE":
-                method = "delete"
-            else:
-                method = "post"
-            endpoints.append(ApiEndpoint(route=route, method=method, name=name))
-
-        apis[api] = endpoints
-
-    return apis
-
-
-def api_providers() -> Dict[Api, Dict[str, ProviderSpec]]:
+def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
     ret = {}
     routing_table_apis = set(
         x.routing_table_api for x in builtin_automatically_routed_apis()

diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index f7d51c64a..8c8084969 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -8,8 +8,8 @@ from typing import Any, Dict, List, Set

 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.distribution import (
-    api_providers,
     builtin_automatically_routed_apis,
+    get_provider_registry,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_provider
@@ -20,7 +20,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
-    all_providers = api_providers()
+    all_providers = get_provider_registry()

     specs = {}
     configs = {}
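The renamed `get_provider_registry()` keeps the same return shape the old `api_providers()` had: a mapping from `Api` to a per-provider-name `ProviderSpec` dict. A hypothetical consumer sketch (requires the `llama_stack` package; the key names shown are examples inferred from this diff):

```python
# Sketch: consuming get_provider_registry(), which maps
# Api -> {provider name -> ProviderSpec}.
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.providers.datatypes import Api

registry = get_provider_registry()
inference_providers = registry[Api.inference]
for name, spec in inference_providers.items():
    # e.g. "meta-reference", "remote::ollama", or plain "remote" (passthrough)
    print(name, spec.config_class)
```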
diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py
new file mode 100644
index 000000000..96de31c4b
--- /dev/null
+++ b/llama_stack/distribution/server/endpoints.py
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import inspect
+from typing import Dict, List
+
+from pydantic import BaseModel
+
+from llama_stack.apis.agents import Agents
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.memory import Memory
+from llama_stack.apis.memory_banks import MemoryBanks
+from llama_stack.apis.models import Models
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.shields import Shields
+from llama_stack.apis.telemetry import Telemetry
+from llama_stack.providers.datatypes import Api
+
+
+class ApiEndpoint(BaseModel):
+    route: str
+    method: str
+    name: str
+
+
+def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
+    apis = {}
+
+    protocols = {
+        Api.inference: Inference,
+        Api.safety: Safety,
+        Api.agents: Agents,
+        Api.memory: Memory,
+        Api.telemetry: Telemetry,
+        Api.models: Models,
+        Api.shields: Shields,
+        Api.memory_banks: MemoryBanks,
+    }
+
+    for api, protocol in protocols.items():
+        endpoints = []
+        protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
+
+        for name, method in protocol_methods:
+            if not hasattr(method, "__webmethod__"):
+                continue
+
+            webmethod = method.__webmethod__
+            route = webmethod.route
+
+            if webmethod.method == "GET":
+                method = "get"
+            elif webmethod.method == "DELETE":
+                method = "delete"
+            else:
+                method = "post"
+            endpoints.append(ApiEndpoint(route=route, method=method, name=name))
+
+        apis[api] = endpoints
+
+    return apis

diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 16b1fb619..1ac1a1b16 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -39,10 +39,11 @@ from llama_stack.providers.utils.telemetry.tracing import (
 )
 from llama_stack.distribution.datatypes import *  # noqa: F403

-from llama_stack.distribution.distribution import api_endpoints
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls_with_routing

+from .endpoints import get_all_api_endpoints
+

 def is_async_iterator_type(typ):
     if hasattr(typ, "__origin__"):
@@ -299,7 +300,7 @@ def main(
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])

-    all_endpoints = api_endpoints()
+    all_endpoints = get_all_api_endpoints()

     if config.apis_to_serve:
         apis_to_serve = set(config.apis_to_serve)

diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index a9a3d86e9..d661b6649 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -25,13 +25,6 @@ class Api(Enum):
     memory_banks = "memory_banks"


-@json_schema_type
-class ApiEndpoint(BaseModel):
-    route: str
-    method: str
-    name: str
-
-
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
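The `__webmethod__` discovery that moved into `endpoints.py` can be illustrated standalone. The decorator below is a simplified stand-in for `llama_models.schema_utils.webmethod` (an assumption for the sketch, not the real implementation); the harvesting loop is the same shape as `get_all_api_endpoints`:

```python
# Sketch: route metadata attached by a decorator, harvested via inspect.
import inspect
from typing import Protocol


class WebMethod:
    def __init__(self, route: str, method: str):
        self.route, self.method = route, method


def webmethod(route: str, method: str = "POST"):
    def wrap(fn):
        fn.__webmethod__ = WebMethod(route, method)  # metadata only
        return fn
    return wrap


class Health(Protocol):
    @webmethod(route="/health", method="GET")
    async def health(self): ...


for name, fn in inspect.getmembers(Health, predicate=inspect.isfunction):
    if hasattr(fn, "__webmethod__"):
        wm = fn.__webmethod__
        print(name, wm.method, wm.route)  # -> health GET /health
```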
From fe4aabd690c0fe812812363d16a2df8f72763261 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 2 Oct 2024 14:05:59 -0700
Subject: [PATCH 03/11] provider_id => provider_type, adapter_id => adapter_type

---
 docs/resources/llama-stack-spec.html          | 12 +++++-----
 docs/resources/llama-stack-spec.yaml          | 12 +++++-----
 llama_stack/apis/memory_banks/memory_banks.py |  2 +-
 llama_stack/apis/models/models.py             |  2 +-
 llama_stack/apis/shields/shields.py           |  2 +-
 llama_stack/cli/stack/list_providers.py       |  4 ++--
 llama_stack/distribution/configure.py         | 10 ++++-----
 llama_stack/distribution/datatypes.py         |  2 +-
 llama_stack/distribution/distribution.py      |  2 +-
 llama_stack/distribution/request_headers.py   |  4 ++--
 llama_stack/distribution/resolver.py          | 16 +++++++-------
 .../docker/llamastack-local-cpu/run.yaml      | 10 ++++-----
 .../docker/llamastack-local-gpu/run.yaml      | 10 ++++-----
 llama_stack/distribution/utils/dynamic.py     |  4 ++--
 llama_stack/providers/datatypes.py            | 18 +++++++--------
 llama_stack/providers/registry/agents.py      |  4 ++--
 llama_stack/providers/registry/inference.py   | 22 +++++++++----------
 llama_stack/providers/registry/memory.py      |  8 +++----
 llama_stack/providers/registry/safety.py      |  8 +++----
 llama_stack/providers/registry/telemetry.py   |  6 +++---
 tests/examples/local-run.yaml                 | 10 ++++-----
 21 files changed, 83 insertions(+), 85 deletions(-)

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index c77ebe2a7..814c2edef 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -4783,7 +4783,7 @@
                 "provider_config": {
                     "type": "object",
                     "properties": {
-                        "provider_id": {
+                        "provider_type": {
                             "type": "string"
                         },
                         "config": {
@@ -4814,7 +4814,7 @@
                 },
                 "additionalProperties": false,
                 "required": [
-                    "provider_id",
+                    "provider_type",
                     "config"
                 ]
             }
@@ -4843,7 +4843,7 @@
                 "provider_config": {
                     "type": "object",
                     "properties": {
-                        "provider_id": {
+                        "provider_type": {
                             "type": "string"
                         },
                         "config": {
@@ -4874,7 +4874,7 @@
                 },
                 "additionalProperties": false,
                 "required": [
-                    "provider_id",
+                    "provider_type",
                     "config"
                 ]
             }
@@ -4894,7 +4894,7 @@
                 "provider_config": {
                     "type": "object",
                     "properties": {
-                        "provider_id": {
+                        "provider_type": {
                             "type": "string"
                         },
                         "config": {
@@ -4925,7 +4925,7 @@
                 },
                 "additionalProperties": false,
                 "required": [
-                    "provider_id",
+                    "provider_type",
                     "config"
                 ]
             }

diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml
index 83b415649..3557365d5 100644
--- a/docs/resources/llama-stack-spec.yaml
+++ b/docs/resources/llama-stack-spec.yaml
@@ -1117,10 +1117,10 @@ components:
           - type: array
           - type: object
         type: object
-      provider_id:
+      provider_type:
         type: string
       required:
-      - provider_id
+      - provider_type
       - config
       type: object
     required:
@@ -1362,10 +1362,10 @@ components:
           - type: array
           - type: object
         type: object
-      provider_id:
+      provider_type:
         type: string
       required:
-      - provider_id
+      - provider_type
       - config
       type: object
     required:
@@ -1916,10 +1916,10 @@ components:
           - type: array
           - type: object
         type: object
-      provider_id:
+      provider_type:
         type: string
       required:
-      - provider_id
+      - provider_type
       - config
       type: object
     shield_type:

diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index b4e35fb0c..53ca83e84 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -18,7 +18,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig
 class MemoryBankSpec(BaseModel):
     bank_type: MemoryBankType
     provider_config: GenericProviderConfig = Field(
-        description="Provider config for the model, including provider_id, and corresponding config. ",
+        description="Provider config for the model, including provider_type, and corresponding config. ",
     )
", + description="Provider config for the model, including provider_type, and corresponding config. ", ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 006178b5d..2b8242263 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -16,7 +16,7 @@ from llama_stack.distribution.datatypes import GenericProviderConfig class ShieldSpec(BaseModel): shield_type: str provider_config: GenericProviderConfig = Field( - description="Provider config for the model, including provider_id, and corresponding config. ", + description="Provider config for the model, including provider_type, and corresponding config. ", ) diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py index 25875ecbf..96e978826 100644 --- a/llama_stack/cli/stack/list_providers.py +++ b/llama_stack/cli/stack/list_providers.py @@ -47,11 +47,11 @@ class StackListProviders(Subcommand): rows = [] for spec in providers_for_api.values(): - if spec.provider_id == "sample": + if spec.provider_type == "sample": continue rows.append( [ - spec.provider_id, + spec.provider_type, ",".join(spec.pip_packages), ] ) diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index e9b682dc0..e03b201ec 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -109,7 +109,7 @@ def configure_api_providers( routing_entries.append( RoutableProviderConfig( routing_key=routing_key, - provider_id=p, + provider_type=p, config=cfg.dict(), ) ) @@ -120,7 +120,7 @@ def configure_api_providers( routing_entries.append( RoutableProviderConfig( routing_key=[s.value for s in MetaReferenceShieldType], - provider_id=p, + provider_type=p, config=cfg.dict(), ) ) @@ -133,7 +133,7 @@ def configure_api_providers( routing_entries.append( RoutableProviderConfig( routing_key=routing_key, - provider_id=p, + provider_type=p, config=cfg.dict(), ) ) @@ -153,7 +153,7 @@ def configure_api_providers( routing_entries.append( RoutableProviderConfig( routing_key=routing_key, - provider_id=p, + provider_type=p, config=cfg.dict(), ) ) @@ -164,7 +164,7 @@ def configure_api_providers( ) else: config.api_providers[api_str] = GenericProviderConfig( - provider_id=p, + provider_type=p, config=cfg.dict(), ) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index fa88ad5cf..c18f715fe 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -71,7 +71,7 @@ Provider configurations for each of the APIs provided by this package. E.g. 
 The following is a ProviderRoutingEntry for models:
 - routing_key: Meta-Llama3.1-8B-Instruct
-  provider_id: meta-reference
+  provider_type: meta-reference
   config:
     model: Meta-Llama3.1-8B-Instruct
     quantization: null

diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 0c47fd750..218105f59 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -51,7 +51,7 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
         module = importlib.import_module(f"llama_stack.providers.registry.{name}")
         ret[api] = {
             "remote": remote_provider_spec(api),
-            **{a.provider_id: a for a in module.available_providers()},
+            **{a.provider_type: a for a in module.available_providers()},
         }

     return ret

diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py
index 990fa66d5..bbb1fff9d 100644
--- a/llama_stack/distribution/request_headers.py
+++ b/llama_stack/distribution/request_headers.py
@@ -18,10 +18,10 @@ class NeedsRequestProviderData:
         spec = self.__provider_spec__
         assert spec, f"Provider spec not set on {self.__class__}"

-        provider_id = spec.provider_id
+        provider_type = spec.provider_type
         validator_class = spec.provider_data_validator
         if not validator_class:
-            raise ValueError(f"Provider {provider_id} does not have a validator")
+            raise ValueError(f"Provider {provider_type} does not have a validator")

         val = getattr(_THREAD_LOCAL, "provider_data_header_value", None)
         if not val:

diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 8c8084969..091769d74 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -34,11 +34,11 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
         if isinstance(config, PlaceholderProviderConfig):
             continue

-        if config.provider_id not in providers:
+        if config.provider_type not in providers:
             raise ValueError(
-                f"Unknown provider `{config.provider_id}` is not available for API `{api}`"
+                f"Provider `{config.provider_type}` is not available for API `{api}`"
             )
-        specs[api] = providers[config.provider_id]
+        specs[api] = providers[config.provider_type]
         configs[api] = config

     apis_to_serve = run_config.apis_to_serve or set(
@@ -68,12 +68,12 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
         inner_specs = []
         inner_deps = []
         for rt_entry in routing_table:
-            if rt_entry.provider_id not in providers:
+            if rt_entry.provider_type not in providers:
                 raise ValueError(
-                    f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
+                    f"Provider `{rt_entry.provider_type}` is not available for API `{api}`"
                 )
-            inner_specs.append(providers[rt_entry.provider_id])
-            inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
+            inner_specs.append(providers[rt_entry.provider_type])
+            inner_deps.extend(providers[rt_entry.provider_type].api_dependencies)

         specs[source_api] = RoutingTableProviderSpec(
             api=source_api,
@@ -94,7 +94,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
     sorted_specs = topological_sort(specs.values())
     print(f"Resolved {len(sorted_specs)} providers in topological order")
     for spec in sorted_specs:
-        print(f"  {spec.api}: {spec.provider_id}")
+        print(f"  {spec.api}: {spec.provider_type}")
     print("")
     impls = {}
     for spec in sorted_specs:
diff --git a/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml b/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml
index 0a845582c..aa5bb916f 100644
--- a/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml
+++ b/llama_stack/distribution/templates/docker/llamastack-local-cpu/run.yaml
@@ -18,7 +18,7 @@ api_providers:
     providers:
     - meta-reference
   agents:
-    provider_id: meta-reference
+    provider_type: meta-reference
     config:
       persistence_store:
         namespace: null
@@ -28,22 +28,22 @@ api_providers:
     providers:
     - meta-reference
   telemetry:
-    provider_id: meta-reference
+    provider_type: meta-reference
     config: {}
 routing_table:
   inference:
-  - provider_id: remote::ollama
+  - provider_type: remote::ollama
     config:
       host: localhost
       port: 6000
     routing_key: Meta-Llama3.1-8B-Instruct
   safety:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
    config:
       llama_guard_shield: null
       prompt_guard_shield: null
     routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
   memory:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
     config: {}
     routing_key: vector

diff --git a/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml b/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml
index 66f6cfcef..bb7a2cc0d 100644
--- a/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml
+++ b/llama_stack/distribution/templates/docker/llamastack-local-gpu/run.yaml
@@ -18,7 +18,7 @@ api_providers:
     providers:
     - meta-reference
   agents:
-    provider_id: meta-reference
+    provider_type: meta-reference
     config:
       persistence_store:
         namespace: null
@@ -28,11 +28,11 @@ api_providers:
     providers:
     - meta-reference
   telemetry:
-    provider_id: meta-reference
+    provider_type: meta-reference
     config: {}
 routing_table:
   inference:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
     config:
       model: Llama3.1-8B-Instruct
       quantization: null
@@ -41,12 +41,12 @@ routing_table:
       torch_seed: null
       max_seq_len: 4096
       max_batch_size: 1
     routing_key: Llama3.1-8B-Instruct
   safety:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
     config:
       llama_guard_shield: null
       prompt_guard_shield: null
     routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
   memory:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
     config: {}
     routing_key: vector

diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py
index 7c2ac2e6a..91aeb4ac7 100644
--- a/llama_stack/distribution/utils/dynamic.py
+++ b/llama_stack/distribution/utils/dynamic.py
@@ -46,11 +46,11 @@ async def instantiate_provider(
         assert isinstance(provider_config, List)
         routing_table = provider_config

-        inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
+        inner_specs = {x.provider_type: x for x in provider_spec.inner_specs}
         inner_impls = []
         for routing_entry in routing_table:
             impl = await instantiate_provider(
-                inner_specs[routing_entry.provider_id],
+                inner_specs[routing_entry.provider_type],
                 deps,
                 routing_entry,
             )

diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index d661b6649..a328acd6b 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -28,7 +28,7 @@
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
-    provider_id: str
+    provider_type: str
     config_class: str = Field(
         ...,
         description="Fully-qualified classname of the config for this provider",
@@ -56,7 +56,7 @@ class RoutableProvider(Protocol):

 class GenericProviderConfig(BaseModel):
-    provider_id: str
+    provider_type: str
     config: Dict[str, Any]
@@ -76,7 +76,7 @@ class RoutableProviderConfig(GenericProviderConfig):
 # Example: /inference, /safety
 @json_schema_type
 class AutoRoutedProviderSpec(ProviderSpec):
-    provider_id: str = "router"
+    provider_type: str = "router"
     config_class: str = ""

     docker_image: Optional[str] = None
@@ -101,7 +101,7 @@ class AutoRoutedProviderSpec(ProviderSpec):
 # Example: /models, /shields
 @json_schema_type
 class RoutingTableProviderSpec(ProviderSpec):
-    provider_id: str = "routing_table"
+    provider_type: str = "routing_table"
     config_class: str = ""
     docker_image: Optional[str] = None
@@ -119,7 +119,7 @@ class RoutingTableProviderSpec(ProviderSpec):

 @json_schema_type
 class AdapterSpec(BaseModel):
-    adapter_id: str = Field(
+    adapter_type: str = Field(
         ...,
         description="Unique identifier for this adapter",
     )
@@ -179,8 +179,8 @@ class RemoteProviderConfig(BaseModel):
         return f"http://{self.host}:{self.port}"


-def remote_provider_id(adapter_id: str) -> str:
-    return f"remote::{adapter_id}"
+def remote_provider_type(adapter_type: str) -> str:
+    return f"remote::{adapter_type}"


 @json_schema_type
@@ -226,8 +226,8 @@ def remote_provider_spec(
         if adapter and adapter.config_class
         else "llama_stack.distribution.datatypes.RemoteProviderConfig"
     )
-    provider_id = remote_provider_id(adapter.adapter_id) if adapter else "remote"
+    provider_type = remote_provider_type(adapter.adapter_type) if adapter else "remote"

     return RemoteProviderSpec(
-        api=api, provider_id=provider_id, config_class=config_class, adapter=adapter
+        api=api, provider_type=provider_type, config_class=config_class, adapter=adapter
     )
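A quick illustration of the naming scheme this rename settles on, mirroring `remote_provider_type` from the diff above (adapter names are examples drawn from the registries below):

```python
# Sketch: how adapter types map to registry keys after this rename.
def remote_provider_type(adapter_type: str) -> str:
    return f"remote::{adapter_type}"

assert remote_provider_type("ollama") == "remote::ollama"
assert remote_provider_type("hf::serverless") == "remote::hf::serverless"
```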
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index 16a872572..2603b5faf 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -14,7 +14,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.agents,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "matplotlib",
                 "pillow",
@@ -33,7 +33,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.agents,
             adapter=AdapterSpec(
-                adapter_id="sample",
+                adapter_type="sample",
                 pip_packages=[],
                 module="llama_stack.providers.adapters.agents.sample",
                 config_class="llama_stack.providers.adapters.agents.sample.SampleConfig",

diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index 8f9786a95..47e142201 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.inference,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "accelerate",
                 "blobfile",
@@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="sample",
+                adapter_type="sample",
                 pip_packages=[],
                 module="llama_stack.providers.adapters.inference.sample",
                 config_class="llama_stack.providers.adapters.inference.sample.SampleConfig",
@@ -39,7 +39,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="ollama",
+                adapter_type="ollama",
                 pip_packages=["ollama"],
                 module="llama_stack.providers.adapters.inference.ollama",
             ),
@@ -47,7 +47,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="tgi",
+                adapter_type="tgi",
                 pip_packages=["huggingface_hub", "aiohttp"],
                 module="llama_stack.providers.adapters.inference.tgi",
                 config_class="llama_stack.providers.adapters.inference.tgi.TGIImplConfig",
@@ -56,7 +56,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="hf::serverless",
+                adapter_type="hf::serverless",
                 pip_packages=["huggingface_hub", "aiohttp"],
                 module="llama_stack.providers.adapters.inference.tgi",
                 config_class="llama_stack.providers.adapters.inference.tgi.InferenceAPIImplConfig",
@@ -65,7 +65,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="hf::endpoint",
+                adapter_type="hf::endpoint",
                 pip_packages=["huggingface_hub", "aiohttp"],
                 module="llama_stack.providers.adapters.inference.tgi",
                 config_class="llama_stack.providers.adapters.inference.tgi.InferenceEndpointImplConfig",
@@ -74,7 +74,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="fireworks",
+                adapter_type="fireworks",
                 pip_packages=[
                     "fireworks-ai",
                 ],
@@ -85,7 +85,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="together",
+                adapter_type="together",
                 pip_packages=[
                     "together",
                 ],
@@ -97,10 +97,8 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
-                adapter_id="bedrock",
-                pip_packages=[
-                    "boto3"
-                ],
+                adapter_type="bedrock",
+                pip_packages=["boto3"],
                 module="llama_stack.providers.adapters.inference.bedrock",
                 config_class="llama_stack.providers.adapters.inference.bedrock.BedrockConfig",
             ),

diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py
index d6776ff69..4687e262c 100644
--- a/llama_stack/providers/registry/memory.py
+++ b/llama_stack/providers/registry/memory.py
@@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.memory,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.impls.meta_reference.memory",
             config_class="llama_stack.providers.impls.meta_reference.memory.FaissImplConfig",
@@ -42,7 +42,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             Api.memory,
             AdapterSpec(
-                adapter_id="chromadb",
+                adapter_type="chromadb",
                 pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
                 module="llama_stack.providers.adapters.memory.chroma",
             ),
@@ -50,7 +50,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             Api.memory,
             AdapterSpec(
-                adapter_id="pgvector",
+                adapter_type="pgvector",
                 pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"],
                 module="llama_stack.providers.adapters.memory.pgvector",
                 config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
@@ -59,7 +59,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.memory,
             adapter=AdapterSpec(
-                adapter_id="sample",
+                adapter_type="sample",
                 pip_packages=[],
                 module="llama_stack.providers.adapters.memory.sample",
                 config_class="llama_stack.providers.adapters.memory.sample.SampleConfig",

diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py
index e0022f02b..58307be11 100644
--- a/llama_stack/providers/registry/safety.py
+++ b/llama_stack/providers/registry/safety.py
@@ -19,7 +19,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.safety,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[
                 "codeshield",
                 "transformers",
@@ -34,7 +34,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.safety,
             adapter=AdapterSpec(
-                adapter_id="sample",
+                adapter_type="sample",
                 pip_packages=[],
                 module="llama_stack.providers.adapters.safety.sample",
                 config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
@@ -43,7 +43,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.safety,
             adapter=AdapterSpec(
-                adapter_id="bedrock",
+                adapter_type="bedrock",
                 pip_packages=["boto3"],
                 module="llama_stack.providers.adapters.safety.bedrock",
                 config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
@@ -52,7 +52,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.safety,
             adapter=AdapterSpec(
-                adapter_id="together",
+                adapter_type="together",
                 pip_packages=[
                     "together",
                 ],

diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py
index 02b71077e..39bcb75d8 100644
--- a/llama_stack/providers/registry/telemetry.py
+++ b/llama_stack/providers/registry/telemetry.py
@@ -13,7 +13,7 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.telemetry,
-            provider_id="meta-reference",
+            provider_type="meta-reference",
             pip_packages=[],
             module="llama_stack.providers.impls.meta_reference.telemetry",
             config_class="llama_stack.providers.impls.meta_reference.telemetry.ConsoleConfig",
@@ -21,7 +21,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.telemetry,
             adapter=AdapterSpec(
-                adapter_id="sample",
+                adapter_type="sample",
                 pip_packages=[],
                 module="llama_stack.providers.adapters.telemetry.sample",
                 config_class="llama_stack.providers.adapters.telemetry.sample.SampleConfig",
@@ -30,7 +30,7 @@ def available_providers() -> List[ProviderSpec]:
         remote_provider_spec(
             api=Api.telemetry,
             adapter=AdapterSpec(
-                adapter_id="opentelemetry-jaeger",
+                adapter_type="opentelemetry-jaeger",
                 pip_packages=[
                     "opentelemetry-api",
                     "opentelemetry-sdk",

diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml
index 98d105233..94340c4d1 100644
--- a/tests/examples/local-run.yaml
+++ b/tests/examples/local-run.yaml
@@ -18,7 +18,7 @@ api_providers:
     providers:
     - meta-reference
   agents:
-    provider_id: meta-reference
+    provider_type: meta-reference
     config:
       persistence_store:
         namespace: null
@@ -28,11 +28,11 @@ api_providers:
     providers:
     - meta-reference
   telemetry:
-    provider_id: meta-reference
+    provider_type: meta-reference
     config: {}
 routing_table:
   inference:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
     config:
       model: Meta-Llama3.1-8B-Instruct
       quantization: null
@@ -41,7 +41,7 @@ routing_table:
       max_batch_size: 1
     routing_key: Meta-Llama3.1-8B-Instruct
   safety:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
     config:
       llama_guard_shield:
         model: Llama-Guard-3-1B
@@ -52,6 +52,6 @@ routing_table:
         excluded_categories: []
         disable_input_check: false
         disable_output_check: false
       prompt_guard_shield:
         model: Prompt-Guard-86M
     routing_key: ["llama_guard", "code_scanner_guard", "injection_shield", "jailbreak_shield"]
   memory:
-  - provider_id: meta-reference
+  - provider_type: meta-reference
     config: {}
     routing_key: vector
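To see what the rename means for run configs, here is a hypothetical sketch that validates one routing-table entry with the pydantic models from `llama_stack/providers/datatypes.py` as they stand after this patch (entry values are illustrative; requires `pyyaml`):

```python
# Sketch: parsing a post-rename routing entry into RoutableProviderConfig.
import yaml

from llama_stack.providers.datatypes import RoutableProviderConfig

entry = yaml.safe_load(
    """
provider_type: remote::ollama
config:
  host: localhost
  port: 6000
routing_key: Meta-Llama3.1-8B-Instruct
"""
)
cfg = RoutableProviderConfig(**entry)
print(cfg.provider_type, cfg.routing_key)
# -> remote::ollama Meta-Llama3.1-8B-Instruct
```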
From 01d93be948cfa613ba06431d1fadc0856a6ec672 Mon Sep 17 00:00:00 2001
From: Adrian Cole <64215+codefromthecrypt@users.noreply.github.com>
Date: Thu, 3 Oct 2024 05:26:20 +0800
Subject: [PATCH 04/11] Adds markdown-link-check and fixes a broken link (#165)

Signed-off-by: Adrian Cole
Co-authored-by: Ashwin Bharambe
---
 .pre-commit-config.yaml                                   | 6 ++++++
 docs/cli_reference.md                                     | 2 +-
 llama_stack/providers/utils/inference/augment_messages.py | 3 ++-
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c00ea3040..555a475b2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -51,3 +51,9 @@ repos:
 #     hooks:
 #       - id: pydoclint
 #         args: [--config=pyproject.toml]
+
+- repo: https://github.com/tcort/markdown-link-check
+  rev: v3.11.2
+  hooks:
+    - id: markdown-link-check
+      args: ['--quiet']

diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index 28874641f..3541d0b4e 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -5,7 +5,7 @@ The `llama` CLI tool helps you setup and use the Llama toolchain & agentic syste
 ### Subcommands
 1. `download`: `llama` cli tools supports downloading the model from Meta or Hugging Face.
 2. `model`: Lists available models and their properties.
-3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](/docs/cli_reference.md#step-3-building-configuring-and-running-llama-stack-servers).
+3. `stack`: Allows you to build and run a Llama Stack server. You can read more about this [here](cli_reference.md#step-3-building-and-configuring-llama-stack-distributions).

 ### Sample Usage

diff --git a/llama_stack/providers/utils/inference/augment_messages.py b/llama_stack/providers/utils/inference/augment_messages.py
index 10375cf0e..613a39525 100644
--- a/llama_stack/providers/utils/inference/augment_messages.py
+++ b/llama_stack/providers/utils/inference/augment_messages.py
@@ -34,7 +34,8 @@ def augment_messages_for_tools(request: ChatCompletionRequest) -> List[Message]:
         return request.messages

     if model.model_family == ModelFamily.llama3_1 or (
-        model.model_family == ModelFamily.llama3_2 and is_multimodal(model.core_model_id)
+        model.model_family == ModelFamily.llama3_2
+        and is_multimodal(model.core_model_id)
     ):
         # llama3.1 and llama3.2 multimodal models follow the same tool prompt format
         return augment_messages_for_tools_llama_3_1(request)

From 8d049000e3fd7a060238376e7d5e23b8a527e3cd Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 2 Oct 2024 15:13:24 -0700
Subject: [PATCH 05/11] Add an introspection "Api.inspect" API

---
 docs/openapi_generator/generate.py           |   2 +
 docs/resources/llama-stack-spec.html         | 203 +++++++++++++++++--
 docs/resources/llama-stack-spec.yaml         | 127 +++++++++++-
 llama_stack/apis/inspect/__init__.py         |   7 +
 llama_stack/apis/inspect/client.py           |  82 ++++++++
 llama_stack/apis/inspect/inspect.py          |  40 ++++
 llama_stack/distribution/datatypes.py        |  47 +++++
 llama_stack/distribution/distribution.py     |   2 +
 llama_stack/distribution/inspect.py          |  54 +++++
 llama_stack/distribution/resolver.py         |  70 ++++++-
 llama_stack/distribution/server/endpoints.py |   3 +
 llama_stack/distribution/server/server.py    |  23 +--
 llama_stack/distribution/utils/dynamic.py    |  60 ------
 llama_stack/providers/datatypes.py           |  73 +------
 14 files changed, 619 insertions(+), 174 deletions(-)
 create mode 100644 llama_stack/apis/inspect/__init__.py
 create mode 100644 llama_stack/apis/inspect/client.py
 create mode 100644 llama_stack/apis/inspect/inspect.py
 create mode 100644 llama_stack/distribution/inspect.py

diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py
index c5ba23b14..c5b156bb8 100644
--- a/docs/openapi_generator/generate.py
+++ b/docs/openapi_generator/generate.py
@@ -46,6 +46,7 @@ from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403
 from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.shields import *  # noqa: F403
+from llama_stack.apis.inspect import *  # noqa: F403


 class LlamaStack(
@@ -63,6 +64,7 @@ class LlamaStack(
     Evaluations,
     Models,
     Shields,
+    Inspect,
 ):
     pass

diff --git a/docs/resources/llama-stack-spec.html b/docs/resources/llama-stack-spec.html
index 814c2edef..0d06ce03d 100644
--- a/docs/resources/llama-stack-spec.html
+++ b/docs/resources/llama-stack-spec.html
@@ -21,7 +21,7 @@
     "info": {
         "title": "[DRAFT] Llama Stack Specification",
         "version": "0.0.1",
-        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-09-23 16:58:41.469308"
+        "description": "This is the specification of the llama stack that provides\n    a set of endpoints and their corresponding interfaces that are tailored to\n    best leverage Llama Models. The specification is still in draft and subject to change.\n    Generated at 2024-10-02 15:40:53.008257"
     },
     "servers": [
         {
@@ -1542,6 +1542,36 @@
             ]
         }
     },
+    "/health": {
+        "get": {
+            "responses": {
+                "200": {
+                    "description": "OK",
+                    "content": {
+                        "application/json": {
+                            "schema": {
+                                "$ref": "#/components/schemas/HealthInfo"
+                            }
+                        }
+                    }
+                }
+            },
+            "tags": [
+                "Inspect"
+            ],
+            "parameters": [
+                {
+                    "name": "X-LlamaStack-ProviderData",
+                    "in": "header",
+                    "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+                    "required": false,
+                    "schema": {
+                        "type": "string"
+                    }
+                }
+            ]
+        }
+    },
     "/memory/insert": {
         "post": {
             "responses": {
@@ -1665,6 +1695,75 @@
             ]
         }
     },
+    "/providers/list": {
+        "get": {
+            "responses": {
+                "200": {
+                    "description": "OK",
+                    "content": {
+                        "application/json": {
+                            "schema": {
+                                "type": "object",
+                                "additionalProperties": {
+                                    "$ref": "#/components/schemas/ProviderInfo"
+                                }
+                            }
+                        }
+                    }
+                }
+            },
+            "tags": [
+                "Inspect"
+            ],
+            "parameters": [
+                {
+                    "name": "X-LlamaStack-ProviderData",
+                    "in": "header",
+                    "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+                    "required": false,
+                    "schema": {
+                        "type": "string"
+                    }
+                }
+            ]
+        }
+    },
+    "/routes/list": {
+        "get": {
+            "responses": {
+                "200": {
+                    "description": "OK",
+                    "content": {
+                        "application/json": {
+                            "schema": {
+                                "type": "object",
+                                "additionalProperties": {
+                                    "type": "array",
+                                    "items": {
+                                        "$ref": "#/components/schemas/RouteInfo"
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            },
+            "tags": [
+                "Inspect"
+            ],
+            "parameters": [
+                {
+                    "name": "X-LlamaStack-ProviderData",
+                    "in": "header",
+                    "description": "JSON-encoded provider data which will be made available to the adapter servicing the API",
+                    "required": false,
+                    "schema": {
+                        "type": "string"
+                    }
+                }
+            ]
+        }
+    },
     "/shields/list": {
         "get": {
             "responses": {
@@ -5086,6 +5185,18 @@
                 "job_uuid"
             ]
         },
+        "HealthInfo": {
+            "type": "object",
+            "properties": {
+                "status": {
+                    "type": "string"
+                }
+            },
+            "additionalProperties": false,
+            "required": [
+                "status"
+            ]
+        },
         "InsertDocumentsRequest": {
             "type": "object",
             "properties": {
@@ -5108,6 +5219,45 @@
                 "documents"
             ]
        },
+        "ProviderInfo": {
+            "type": "object",
"properties": { + "provider_type": { + "type": "string" + }, + "description": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "provider_type", + "description" + ] + }, + "RouteInfo": { + "type": "object", + "properties": { + "route": { + "type": "string" + }, + "method": { + "type": "string" + }, + "providers": { + "type": "array", + "items": { + "type": "string" + } + } + }, + "additionalProperties": false, + "required": [ + "route", + "method", + "providers" + ] + }, "LogSeverity": { "type": "string", "enum": [ @@ -6220,19 +6370,34 @@ ], "tags": [ { - "name": "Shields" + "name": "Datasets" + }, + { + "name": "Inspect" + }, + { + "name": "Memory" }, { "name": "BatchInference" }, { - "name": "RewardScoring" + "name": "Agents" + }, + { + "name": "Inference" + }, + { + "name": "Shields" }, { "name": "SyntheticDataGeneration" }, { - "name": "Agents" + "name": "Models" + }, + { + "name": "RewardScoring" }, { "name": "MemoryBanks" @@ -6241,13 +6406,7 @@ "name": "Safety" }, { - "name": "Models" - }, - { - "name": "Inference" - }, - { - "name": "Memory" + "name": "Evaluations" }, { "name": "Telemetry" @@ -6255,12 +6414,6 @@ { "name": "PostTraining" }, - { - "name": "Datasets" - }, - { - "name": "Evaluations" - }, { "name": "BuiltinTool", "description": "" @@ -6653,10 +6806,22 @@ "name": "PostTrainingJob", "description": "" }, + { + "name": "HealthInfo", + "description": "" + }, { "name": "InsertDocumentsRequest", "description": "" }, + { + "name": "ProviderInfo", + "description": "" + }, + { + "name": "RouteInfo", + "description": "" + }, { "name": "LogSeverity", "description": "" @@ -6787,6 +6952,7 @@ "Datasets", "Evaluations", "Inference", + "Inspect", "Memory", "MemoryBanks", "Models", @@ -6857,6 +7023,7 @@ "FunctionCallToolDefinition", "GetAgentsSessionRequest", "GetDocumentsRequest", + "HealthInfo", "ImageMedia", "InferenceStep", "InsertDocumentsRequest", @@ -6880,6 +7047,7 @@ "PostTrainingJobStatus", "PostTrainingJobStatusResponse", "PreferenceOptimizeRequest", + "ProviderInfo", "QLoraFinetuningConfig", "QueryDocumentsRequest", "QueryDocumentsResponse", @@ -6888,6 +7056,7 @@ "RestAPIMethod", "RewardScoreRequest", "RewardScoringResponse", + "RouteInfo", "RunShieldRequest", "RunShieldResponse", "SafetyViolation", diff --git a/docs/resources/llama-stack-spec.yaml b/docs/resources/llama-stack-spec.yaml index 3557365d5..317d1ee33 100644 --- a/docs/resources/llama-stack-spec.yaml +++ b/docs/resources/llama-stack-spec.yaml @@ -908,6 +908,14 @@ components: required: - document_ids type: object + HealthInfo: + additionalProperties: false + properties: + status: + type: string + required: + - status + type: object ImageMedia: additionalProperties: false properties: @@ -1543,6 +1551,17 @@ components: - hyperparam_search_config - logger_config type: object + ProviderInfo: + additionalProperties: false + properties: + description: + type: string + provider_type: + type: string + required: + - provider_type + - description + type: object QLoraFinetuningConfig: additionalProperties: false properties: @@ -1704,6 +1723,22 @@ components: title: Response from the reward scoring. Batch of (prompt, response, score) tuples that pass the threshold. 
     type: object
+  RouteInfo:
+    additionalProperties: false
+    properties:
+      method:
+        type: string
+      providers:
+        items:
+          type: string
+        type: array
+      route:
+        type: string
+    required:
+    - route
+    - method
+    - providers
+    type: object
   RunShieldRequest:
     additionalProperties: false
     properties:
@@ -2569,7 +2604,7 @@ info:
   description: "This is the specification of the llama stack that provides\n \
     \ a set of endpoints and their corresponding interfaces that are tailored\
     \ to\n        best leverage Llama Models. The specification is still in\
-    \ draft and subject to change.\n        Generated at 2024-09-23 16:58:41.469308"
+    \ draft and subject to change.\n        Generated at 2024-10-02 15:40:53.008257"
   title: '[DRAFT] Llama Stack Specification'
   version: 0.0.1
 jsonSchemaDialect: https://json-schema.org/draft/2020-12/schema
@@ -3093,6 +3128,25 @@ paths:
         description: OK
       tags:
       - Evaluations
+  /health:
+    get:
+      parameters:
+      - description: JSON-encoded provider data which will be made available to the
+          adapter servicing the API
+        in: header
+        name: X-LlamaStack-ProviderData
+        required: false
+        schema:
+          type: string
+      responses:
+        '200':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HealthInfo'
+          description: OK
+      tags:
+      - Inspect
   /inference/chat_completion:
     post:
       parameters:
@@ -3637,6 +3691,27 @@ paths:
         description: OK
       tags:
       - PostTraining
+  /providers/list:
+    get:
+      parameters:
+      - description: JSON-encoded provider data which will be made available to the
+          adapter servicing the API
+        in: header
+        name: X-LlamaStack-ProviderData
+        required: false
+        schema:
+          type: string
+      responses:
+        '200':
+          content:
+            application/json:
+              schema:
+                additionalProperties:
+                  $ref: '#/components/schemas/ProviderInfo'
+                type: object
+          description: OK
+      tags:
+      - Inspect
   /reward_scoring/score:
     post:
       parameters:
@@ -3662,6 +3737,29 @@ paths:
         description: OK
       tags:
       - RewardScoring
+  /routes/list:
+    get:
+      parameters:
+      - description: JSON-encoded provider data which will be made available to the
+          adapter servicing the API
+        in: header
+        name: X-LlamaStack-ProviderData
+        required: false
+        schema:
+          type: string
+      responses:
+        '200':
+          content:
+            application/json:
+              schema:
+                additionalProperties:
+                  items:
+                    $ref: '#/components/schemas/RouteInfo'
+                  type: array
+                type: object
+          description: OK
+      tags:
+      - Inspect
   /safety/run_shield:
     post:
      parameters:
@@ -3807,20 +3905,21 @@ security:
 servers:
 - url: http://any-hosted-llama-stack.com
 tags:
-- name: Shields
+- name: Datasets
+- name: Inspect
+- name: Memory
 - name: BatchInference
-- name: RewardScoring
-- name: SyntheticDataGeneration
 - name: Agents
+- name: Inference
+- name: Shields
+- name: SyntheticDataGeneration
+- name: Models
+- name: RewardScoring
 - name: MemoryBanks
 - name: Safety
-- name: Models
-- name: Inference
-- name: Memory
+- name: Evaluations
 - name: Telemetry
 - name: PostTraining
-- name: Datasets
-- name: Evaluations
 - description:
   name: BuiltinTool
 - description:
   name: PostTrainingJob
+- description:
+  name: HealthInfo
 - description:
   name: InsertDocumentsRequest
+- description:
+  name: ProviderInfo
+- description:
+  name: RouteInfo
 - description:
   name: LogSeverity
 - description:
@@ -4236,6 +4341,7 @@ x-tagGroups:
   - Datasets
   - Evaluations
   - Inference
+  - Inspect
   - Memory
   - MemoryBanks
   - Models
@@ -4303,6 +4409,7 @@ x-tagGroups:
   - FunctionCallToolDefinition
   - GetAgentsSessionRequest
   - GetDocumentsRequest
+  - HealthInfo
   - ImageMedia
   - InferenceStep
   - InsertDocumentsRequest
@@ -4326,6 +4433,7 @@ x-tagGroups:
   - PostTrainingJobStatus
   - PostTrainingJobStatusResponse
   - PreferenceOptimizeRequest
+  - ProviderInfo
   - QLoraFinetuningConfig
   - QueryDocumentsRequest
   - QueryDocumentsResponse
@@ -4334,6 +4442,7 @@ x-tagGroups:
   - RestAPIMethod
   - RewardScoreRequest
   - RewardScoringResponse
+  - RouteInfo
   - RunShieldRequest
   - RunShieldResponse
   - SafetyViolation

diff --git a/llama_stack/apis/inspect/__init__.py b/llama_stack/apis/inspect/__init__.py
new file mode 100644
index 000000000..88ba8e908
--- /dev/null
+++ b/llama_stack/apis/inspect/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .inspect import *  # noqa: F401 F403

diff --git a/llama_stack/apis/inspect/client.py b/llama_stack/apis/inspect/client.py
new file mode 100644
index 000000000..65d8b83ed
--- /dev/null
+++ b/llama_stack/apis/inspect/client.py
@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+
+from typing import List
+
+import fire
+import httpx
+from termcolor import cprint
+
+from .inspect import *  # noqa: F403
+
+
+class InspectClient(Inspect):
+    def __init__(self, base_url: str):
+        self.base_url = base_url
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def list_providers(self) -> Dict[str, ProviderInfo]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/providers/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            print(response.json())
+            return {
+                k: [ProviderInfo(**vi) for vi in v] for k, v in response.json().items()
+            }
+
+    async def list_routes(self) -> Dict[str, List[RouteInfo]]:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/routes/list",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            return {
+                k: [RouteInfo(**vi) for vi in v] for k, v in response.json().items()
+            }
+
+    async def health(self) -> HealthInfo:
+        async with httpx.AsyncClient() as client:
+            response = await client.get(
+                f"{self.base_url}/health",
+                headers={"Content-Type": "application/json"},
+            )
+            response.raise_for_status()
+            j = response.json()
+            if j is None:
+                return None
+            return HealthInfo(**j)
+
+
+async def run_main(host: str, port: int):
+    client = InspectClient(f"http://{host}:{port}")
+
+    response = await client.list_providers()
+    cprint(f"list_providers response={response}", "green")
+
+    response = await client.list_routes()
+    cprint(f"list_routes response={response}", "blue")
+
+    response = await client.health()
+    cprint(f"health response={response}", "yellow")
+
+
+def main(host: str, port: int):
+    asyncio.run(run_main(host, port))
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
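A possible quick exercise of the client above against a locally running stack; the port is an assumption here, so pass whatever your server actually listens on:

```python
# Hypothetical usage of InspectClient; assumes a server on localhost:5000.
import asyncio

from llama_stack.apis.inspect.client import InspectClient


async def check() -> None:
    client = InspectClient("http://localhost:5000")
    print(await client.health())   # e.g. HealthInfo(status='OK')
    routes = await client.list_routes()
    print(sorted(routes.keys()))   # one entry per served API


asyncio.run(check())
```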
diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py
new file mode 100644
index 000000000..ca444098c
--- /dev/null
+++ b/llama_stack/apis/inspect/inspect.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List, Protocol
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel
+
+
+@json_schema_type
+class ProviderInfo(BaseModel):
+    provider_type: str
+    description: str
+
+
+@json_schema_type
+class RouteInfo(BaseModel):
+    route: str
+    method: str
+    providers: List[str]
+
+
+@json_schema_type
+class HealthInfo(BaseModel):
+    status: str
+    # TODO: add a provider level status
+
+
+class Inspect(Protocol):
+    @webmethod(route="/providers/list", method="GET")
+    async def list_providers(self) -> Dict[str, ProviderInfo]: ...
+
+    @webmethod(route="/routes/list", method="GET")
+    async def list_routes(self) -> Dict[str, List[RouteInfo]]: ...
+
+    @webmethod(route="/health", method="GET")
+    async def health(self) -> HealthInfo: ...

diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index c18f715fe..2be6ede26 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -17,6 +17,53 @@
 LLAMA_STACK_BUILD_CONFIG_VERSION = "v1"
 LLAMA_STACK_RUN_CONFIG_VERSION = "v1"

+RoutingKey = Union[str, List[str]]
+
+
+class GenericProviderConfig(BaseModel):
+    provider_type: str
+    config: Dict[str, Any]
+
+
+class RoutableProviderConfig(GenericProviderConfig):
+    routing_key: RoutingKey
+
+
+class PlaceholderProviderConfig(BaseModel):
+    """Placeholder provider config for API whose provider are defined in routing_table"""
+
+    providers: List[str]
+
+
+# Example: /inference, /safety
+class AutoRoutedProviderSpec(ProviderSpec):
+    provider_type: str = "router"
+    config_class: str = ""
+
+    docker_image: Optional[str] = None
+    routing_table_api: Api
+    module: str
+    provider_data_validator: Optional[str] = Field(
+        default=None,
+    )
+
+    @property
+    def pip_packages(self) -> List[str]:
+        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
+
+
+# Example: /models, /shields
+@json_schema_type
+class RoutingTableProviderSpec(ProviderSpec):
+    provider_type: str = "routing_table"
+    config_class: str = ""
+    docker_image: Optional[str] = None
+
+    inner_specs: List[ProviderSpec]
+    module: str
+    pip_packages: List[str] = Field(default_factory=list)
+
+
 @json_schema_type
 class DistributionSpec(BaseModel):
     description: Optional[str] = Field(

diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index 218105f59..eea066d1f 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -46,6 +46,8 @@ def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
     for api in stack_apis():
         if api in routing_table_apis:
             continue
+        if api == Api.inspect:
+            continue

         name = api.name.lower()
         module = importlib.import_module(f"llama_stack.providers.registry.{name}")
diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py
new file mode 100644
index 000000000..acd7ab7f8
--- /dev/null
+++ b/llama_stack/distribution/inspect.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Dict, List
+from llama_stack.apis.inspect import *  # noqa: F403
+
+
+from llama_stack.distribution.distribution import get_provider_registry
+from llama_stack.distribution.server.endpoints import get_all_api_endpoints
+from llama_stack.providers.datatypes import *  # noqa: F403
+
+
+def is_passthrough(spec: ProviderSpec) -> bool:
+    return isinstance(spec, RemoteProviderSpec) and spec.adapter is None
+
+
+class DistributionInspectImpl(Inspect):
+    def __init__(self):
+        pass
+
+    async def list_providers(self) -> Dict[str, List[ProviderInfo]]:
+        ret = {}
+        all_providers = get_provider_registry()
+        for api, providers in all_providers.items():
+            ret[api.value] = [
+                ProviderInfo(
+                    provider_type=p.provider_type,
+                    description="Passthrough" if is_passthrough(p) else "",
+                )
+                for p in providers.values()
+            ]
+
+        return ret
+
+    async def list_routes(self) -> Dict[str, List[RouteInfo]]:
+        ret = {}
+        all_endpoints = get_all_api_endpoints()
+
+        for api, endpoints in all_endpoints.items():
+            ret[api.value] = [
+                RouteInfo(
+                    route=e.route,
+                    method=e.method,
+                    providers=[],
+                )
+                for e in endpoints
+            ]
+        return ret
+
+    async def health(self) -> HealthInfo:
+        return HealthInfo(status="OK")

diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 091769d74..ae7d9ab40 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import importlib

 from typing import Any, Dict, List, Set
@@ -11,7 +12,8 @@ from llama_stack.distribution.distribution import (
     builtin_automatically_routed_apis,
     get_provider_registry,
 )
-from llama_stack.distribution.utils.dynamic import instantiate_provider
+from llama_stack.distribution.inspect import DistributionInspectImpl
+from llama_stack.distribution.utils.dynamic import instantiate_class_type


 async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
@@ -57,7 +59,6 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
         if info.router_api.value not in apis_to_serve:
             continue

-        print("router_api", info.router_api)
         if info.router_api.value not in run_config.routing_table:
             raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
@@ -104,6 +105,14 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An

         impls[api] = impl

+    impls[Api.inspect] = DistributionInspectImpl()
+    specs[Api.inspect] = InlineProviderSpec(
+        api=Api.inspect,
+        provider_type="__distribution_builtin__",
+        config_class="",
+        module="",
+    )
+
     return impls, specs
@@ -127,3 +136,60 @@ def topological_sort(providers: List[ProviderSpec]) -> List[ProviderSpec]:
             dfs(a, visited, stack)

     return [by_id[x] for x in stack]
+
+
+# returns a class implementing the protocol corresponding to the Api
+async def instantiate_provider(
+    provider_spec: ProviderSpec,
+    deps: Dict[str, Any],
+    provider_config: Union[GenericProviderConfig, RoutingTable],
+):
+    module = importlib.import_module(provider_spec.module)
+
+    args = []
+    if isinstance(provider_spec, RemoteProviderSpec):
+        if provider_spec.adapter:
+            method = "get_adapter_impl"
+        else:
+            method = "get_client_impl"
+
+        assert isinstance(provider_config, GenericProviderConfig)
+        config_type = instantiate_class_type(provider_spec.config_class)
+        config = config_type(**provider_config.config)
+        args = [config, deps]
+    elif isinstance(provider_spec, AutoRoutedProviderSpec):
AutoRoutedProviderSpec): + method = "get_auto_router_impl" + + config = None + args = [provider_spec.api, deps[provider_spec.routing_table_api], deps] + elif isinstance(provider_spec, RoutingTableProviderSpec): + method = "get_routing_table_impl" + + assert isinstance(provider_config, List) + routing_table = provider_config + + inner_specs = {x.provider_type: x for x in provider_spec.inner_specs} + inner_impls = [] + for routing_entry in routing_table: + impl = await instantiate_provider( + inner_specs[routing_entry.provider_type], + deps, + routing_entry, + ) + inner_impls.append((routing_entry.routing_key, impl)) + + config = None + args = [provider_spec.api, inner_impls, routing_table, deps] + else: + method = "get_provider_impl" + + assert isinstance(provider_config, GenericProviderConfig) + config_type = instantiate_class_type(provider_spec.config_class) + config = config_type(**provider_config.config) + args = [config, deps] + + fn = getattr(module, method) + impl = await fn(*args) + impl.__provider_spec__ = provider_spec + impl.__provider_config__ = config + return impl diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index 96de31c4b..601e80e5d 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -11,12 +11,14 @@ from pydantic import BaseModel from llama_stack.apis.agents import Agents from llama_stack.apis.inference import Inference +from llama_stack.apis.inspect import Inspect from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.models import Models from llama_stack.apis.safety import Safety from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry + from llama_stack.providers.datatypes import Api @@ -38,6 +40,7 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]: Api.models: Models, Api.shields: Shields, Api.memory_banks: MemoryBanks, + Api.inspect: Inspect, } for api, protocol in protocols.items(): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 1ac1a1b16..4013264df 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -15,7 +15,6 @@ from collections.abc import ( AsyncIterator as AsyncIteratorABC, ) from contextlib import asynccontextmanager -from http import HTTPStatus from ssl import SSLError from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional @@ -26,7 +25,6 @@ import yaml from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse -from fastapi.routing import APIRoute from pydantic import BaseModel, ValidationError from termcolor import cprint from typing_extensions import Annotated @@ -287,15 +285,6 @@ def main( app = FastAPI() - # Health check is added to enable deploying the docker container image on Kubernetes which require - # a health check that can return 200 for readiness and liveness check - class HealthCheck(BaseModel): - status: str = "OK" - - @app.get("/healthcheck", status_code=HTTPStatus.OK, response_model=HealthCheck) - async def healthcheck(): - return HealthCheck(status="OK") - impls, specs = asyncio.run(resolve_impls_with_routing(config)) if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) @@ -307,6 +296,7 @@ def main( else: apis_to_serve = set(impls.keys()) + 
apis_to_serve.add(Api.inspect) for api_str in apis_to_serve: api = Api(api_str) @@ -340,14 +330,11 @@ def main( ) ) - for route in app.routes: - if isinstance(route, APIRoute): - cprint( - f"Serving {next(iter(route.methods))} {route.path}", - "white", - attrs=["bold"], - ) + cprint(f"Serving API {api_str}", "white", attrs=["bold"]) + for endpoint in endpoints: + cprint(f" {endpoint.method.upper()} {endpoint.route}", "white") + print("") app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) signal.signal(signal.SIGINT, handle_sigint) diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 91aeb4ac7..53b861fe4 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -5,69 +5,9 @@ # the root directory of this source tree. import importlib -from typing import Any, Dict - -from llama_stack.distribution.datatypes import * # noqa: F403 def instantiate_class_type(fully_qualified_name): module_name, class_name = fully_qualified_name.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, class_name) - - -# returns a class implementing the protocol corresponding to the Api -async def instantiate_provider( - provider_spec: ProviderSpec, - deps: Dict[str, Any], - provider_config: Union[GenericProviderConfig, RoutingTable], -): - module = importlib.import_module(provider_spec.module) - - args = [] - if isinstance(provider_spec, RemoteProviderSpec): - if provider_spec.adapter: - method = "get_adapter_impl" - else: - method = "get_client_impl" - - assert isinstance(provider_config, GenericProviderConfig) - config_type = instantiate_class_type(provider_spec.config_class) - config = config_type(**provider_config.config) - args = [config, deps] - elif isinstance(provider_spec, AutoRoutedProviderSpec): - method = "get_auto_router_impl" - - config = None - args = [provider_spec.api, deps[provider_spec.routing_table_api], deps] - elif isinstance(provider_spec, RoutingTableProviderSpec): - method = "get_routing_table_impl" - - assert isinstance(provider_config, List) - routing_table = provider_config - - inner_specs = {x.provider_type: x for x in provider_spec.inner_specs} - inner_impls = [] - for routing_entry in routing_table: - impl = await instantiate_provider( - inner_specs[routing_entry.provider_type], - deps, - routing_entry, - ) - inner_impls.append((routing_entry.routing_key, impl)) - - config = None - args = [provider_spec.api, inner_impls, routing_table, deps] - else: - method = "get_provider_impl" - - assert isinstance(provider_config, GenericProviderConfig) - config_type = instantiate_class_type(provider_spec.config_class) - config = config_type(**provider_config.config) - args = [config, deps] - - fn = getattr(module, method) - impl = await fn(*args) - impl.__provider_spec__ = provider_spec - impl.__provider_config__ = config - return impl diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index a328acd6b..a2e8851a2 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
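+
+# NOTE: Api.inspect below is a built-in API that the distribution server
+# surfaces itself; it intentionally has no entry in the provider registry.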
from enum import Enum -from typing import Any, Dict, List, Optional, Protocol, Union +from typing import Any, List, Optional, Protocol from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field @@ -24,6 +24,9 @@ class Api(Enum): shields = "shields" memory_banks = "memory_banks" + # built-in API + inspect = "inspect" + @json_schema_type class ProviderSpec(BaseModel): @@ -55,68 +58,6 @@ class RoutableProvider(Protocol): async def validate_routing_keys(self, keys: List[str]) -> None: ... -class GenericProviderConfig(BaseModel): - provider_type: str - config: Dict[str, Any] - - -class PlaceholderProviderConfig(BaseModel): - """Placeholder provider config for API whose provider are defined in routing_table""" - - providers: List[str] - - -RoutingKey = Union[str, List[str]] - - -class RoutableProviderConfig(GenericProviderConfig): - routing_key: RoutingKey - - -# Example: /inference, /safety -@json_schema_type -class AutoRoutedProviderSpec(ProviderSpec): - provider_type: str = "router" - config_class: str = "" - - docker_image: Optional[str] = None - routing_table_api: Api - module: str = Field( - ..., - description=""" - Fully-qualified name of the module to import. The module is expected to have: - - - `get_router_impl(config, provider_specs, deps)`: returns the router implementation - """, - ) - provider_data_validator: Optional[str] = Field( - default=None, - ) - - @property - def pip_packages(self) -> List[str]: - raise AssertionError("Should not be called on AutoRoutedProviderSpec") - - -# Example: /models, /shields -@json_schema_type -class RoutingTableProviderSpec(ProviderSpec): - provider_type: str = "routing_table" - config_class: str = "" - docker_image: Optional[str] = None - - inner_specs: List[ProviderSpec] - module: str = Field( - ..., - description=""" - Fully-qualified name of the module to import. 
The module is expected to have: - - - `get_router_impl(config, provider_specs, deps)`: returns the router implementation - """, - ) - pip_packages: List[str] = Field(default_factory=list) - - @json_schema_type class AdapterSpec(BaseModel): adapter_type: str = Field( @@ -179,10 +120,6 @@ class RemoteProviderConfig(BaseModel): return f"http://{self.host}:{self.port}" -def remote_provider_type(adapter_type: str) -> str: - return f"remote::{adapter_type}" - - @json_schema_type class RemoteProviderSpec(ProviderSpec): adapter: Optional[AdapterSpec] = Field( @@ -226,7 +163,7 @@ def remote_provider_spec( if adapter and adapter.config_class else "llama_stack.distribution.datatypes.RemoteProviderConfig" ) - provider_type = remote_provider_type(adapter.adapter_type) if adapter else "remote" + provider_type = f"remote::{adapter.adapter_type}" if adapter else "remote" return RemoteProviderSpec( api=api, provider_type=provider_type, config_class=config_class, adapter=adapter From 703ab9385f9c7bc33474197082a061de6f2d1ae2 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 2 Oct 2024 18:23:02 -0700 Subject: [PATCH 06/11] fix routing table key list --- .pre-commit-config.yaml | 10 +++++----- .../distribution/routers/routing_tables.py | 19 ++++++++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 555a475b2..1c85436c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,8 +52,8 @@ repos: # - id: pydoclint # args: [--config=pyproject.toml] -- repo: https://github.com/tcort/markdown-link-check - rev: v3.11.2 - hooks: - - id: markdown-link-check - args: ['--quiet'] +# - repo: https://github.com/tcort/markdown-link-check +# rev: v3.11.2 +# hooks: +# - id: markdown-link-check +# args: ['--quiet'] diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 02dc942e8..e5db17edc 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -94,12 +94,21 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldSpec]: specs = [] for entry in self.routing_table_config: - specs.append( - ShieldSpec( - shield_type=entry.routing_key, - provider_config=entry, + if isinstance(entry.routing_key, list): + for k in entry.routing_key: + specs.append( + ShieldSpec( + shield_type=k, + provider_config=entry, + ) + ) + else: + specs.append( + ShieldSpec( + shield_type=entry.routing_key, + provider_config=entry, + ) ) - ) return specs async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: From 19ce6bf009a80dbc5ae269532b944e3579764fbd Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 2 Oct 2024 20:43:57 -0700 Subject: [PATCH 07/11] Don't validate prompt-guard anymore --- .../impls/meta_reference/safety/config.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/safety/config.py b/llama_stack/providers/impls/meta_reference/safety/config.py index 734103412..36428078d 100644 --- a/llama_stack/providers/impls/meta_reference/safety/config.py +++ b/llama_stack/providers/impls/meta_reference/safety/config.py @@ -50,20 +50,6 @@ class LlamaGuardShieldConfig(BaseModel): class PromptGuardShieldConfig(BaseModel): model: str = "Prompt-Guard-86M" - @validator("model") - @classmethod - def validate_model(cls, model: str) -> str: - permitted_models = [ - m.descriptor() - for m in safety_models() - if 
m.core_model_id == CoreModelId.prompt_guard_86m - ] - if model not in permitted_models: - raise ValueError( - f"Invalid model: {model}. Must be one of {permitted_models}" - ) - return model - class SafetyConfig(BaseModel): llama_guard_shield: Optional[LlamaGuardShieldConfig] = None From 988a9cada3e7ea296611e20facdd2990f9512b2a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 2 Oct 2024 21:10:56 -0700 Subject: [PATCH 08/11] Don't ask for Api.inspect in stack build --- llama_stack/cli/stack/build.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index b7c25fa1b..ab6861482 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -253,6 +253,8 @@ class StackBuild(Subcommand): for api in Api: if api in routing_table_apis: continue + if api == Api.inspect: + continue providers_for_api = all_providers[api] From e9f615058820ec0a68b4d238b5cdc6d80cde3c36 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 2 Oct 2024 21:31:09 -0700 Subject: [PATCH 09/11] A bit cleanup to avoid breakages --- llama_stack/cli/stack/build.py | 36 ++++++++---------------- llama_stack/distribution/distribution.py | 13 ++++----- 2 files changed, 17 insertions(+), 32 deletions(-) diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index ab6861482..d502e4c84 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -114,10 +114,10 @@ class StackBuild(Subcommand): # save build.yaml spec for building same distribution again if build_config.image_type == ImageType.docker.value: # docker needs build file to be in the llama-stack repo dir to be able to copy over to the image - llama_stack_path = Path(os.path.abspath(__file__)).parent.parent.parent.parent - build_dir = ( - llama_stack_path / "tmp/configs/" - ) + llama_stack_path = Path( + os.path.abspath(__file__) + ).parent.parent.parent.parent + build_dir = llama_stack_path / "tmp/configs/" else: build_dir = DISTRIBS_BASE_DIR / f"llamastack-{build_config.name}" @@ -173,12 +173,7 @@ class StackBuild(Subcommand): def _run_stack_build_command(self, args: argparse.Namespace) -> None: import yaml - from llama_stack.distribution.distribution import ( - Api, - get_provider_registry, - builtin_automatically_routed_apis, - ) - from llama_stack.distribution.utils.dynamic import instantiate_class_type + from llama_stack.distribution.distribution import get_provider_registry from prompt_toolkit import prompt from prompt_toolkit.validation import Validator from termcolor import cprint @@ -212,7 +207,10 @@ class StackBuild(Subcommand): if args.name: maybe_build_config = self._get_build_config_from_name(args) if maybe_build_config: - cprint(f"Building from existing build config for {args.name} in {str(maybe_build_config)}...", "green") + cprint( + f"Building from existing build config for {args.name} in {str(maybe_build_config)}...", + "green", + ) with open(maybe_build_config, "r") as f: build_config = BuildConfig(**yaml.safe_load(f)) self._run_stack_build_command_from_build_config(build_config) @@ -240,24 +238,12 @@ class StackBuild(Subcommand): ) cprint( - f"\n Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.", + "\n Llama Stack is composed of several APIs working together. 
Let's configure the providers (implementations) you want to use for these APIs.", color="green", ) providers = dict() - all_providers = get_provider_registry() - routing_table_apis = set( - x.routing_table_api for x in builtin_automatically_routed_apis() - ) - - for api in Api: - if api in routing_table_apis: - continue - if api == Api.inspect: - continue - - providers_for_api = all_providers[api] - + for api, providers_for_api in get_provider_registry().items(): api_provider = prompt( "> Enter provider for the {} API: (default=meta-reference): ".format( api.value diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index eea066d1f..999646cc0 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -38,17 +38,16 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: ] -def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: - ret = {} +def providable_apis() -> List[Api]: routing_table_apis = set( x.routing_table_api for x in builtin_automatically_routed_apis() ) - for api in stack_apis(): - if api in routing_table_apis: - continue - if api == Api.inspect: - continue + return [api for api in Api if api not in routing_table_apis and api != Api.inspect] + +def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]: + ret = {} + for api in providable_apis(): name = api.name.lower() module = importlib.import_module(f"llama_stack.providers.registry.{name}") ret[api] = { From c02a90e4c82d49c51174a53c2060d94a27f27599 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 3 Oct 2024 05:42:47 -0700 Subject: [PATCH 10/11] Bump version to 0.0.38 --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 327b2ee82..df3221371 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ blobfile fire httpx huggingface-hub -llama-models>=0.0.37 +llama-models>=0.0.38 prompt-toolkit python-dotenv pydantic>=2 diff --git a/setup.py b/setup.py index 3c26c9a84..804c9ba3d 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_stack", - version="0.0.37", + version="0.0.38", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama Stack", From d74501f75cdea8d59bde2acc695a50cd634a9d94 Mon Sep 17 00:00:00 2001 From: raghotham Date: Thu, 3 Oct 2024 10:21:16 -0700 Subject: [PATCH 11/11] Update README.md Added pypi package version --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 936876708..a5172ce5c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ # Llama Stack +[![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) -[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/TZAAYNVtrU) +[![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack) This repository contains the Llama Stack API specifications as well as API Providers and Llama Stack Distributions.
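
With this series applied, a running stack server exposes the built-in Inspect API alongside whatever providers are configured. The sketch below is a quick smoke test, not part of the patches themselves; it assumes httpx (already in requirements.txt) and a server listening on localhost:5000, so substitute whatever host and port your `llama stack run` invocation actually uses:

    import httpx

    BASE_URL = "http://localhost:5000"  # assumed address; adjust to your server

    # GET /health: the built-in implementation always answers {"status": "OK"},
    # replacing the ad-hoc /healthcheck route removed from server.py above.
    print(httpx.get(f"{BASE_URL}/health").json())

    # GET /providers/list: a mapping from API name to the ProviderInfo entries
    # known to the provider registry.
    for api, providers in httpx.get(f"{BASE_URL}/providers/list").json().items():
        print(api, [p["provider_type"] for p in providers])

    # GET /routes/list: the same route table the server prints at startup.
    for api, routes in httpx.get(f"{BASE_URL}/routes/list").json().items():
        for r in routes:
            print(f"{r['method'].upper():6s} {r['route']}")

Because /health now goes through the same endpoint machinery as every other API, Kubernetes-style liveness and readiness probes should point at /health rather than the old /healthcheck path.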