From 4586692dee76968c7be3249b61eede845eab6821 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Sun, 22 Sep 2024 23:15:34 -0700 Subject: [PATCH] issue w/ safety --- llama_stack/apis/memory_banks/memory_banks.py | 6 +- llama_stack/distribution/datatypes.py | 9 +-- llama_stack/distribution/routers/__init__.py | 2 +- llama_stack/distribution/routers/routers.py | 1 + .../distribution/routers/routing_tables.py | 16 ++--- llama_stack/distribution/server/server.py | 15 ++--- llama_stack/distribution/utils/dynamic.py | 4 +- tests/examples/router-local-run.yaml | 63 ++++++++----------- 8 files changed, 50 insertions(+), 66 deletions(-) diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index 23bfb69e1..7c0e981a3 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -7,11 +7,11 @@ from typing import List, Optional, Protocol from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field from llama_stack.apis.memory import MemoryBankType from llama_stack.distribution.datatypes import GenericProviderConfig +from pydantic import BaseModel, Field @json_schema_type @@ -23,10 +23,10 @@ class MemoryBankSpec(BaseModel): class MemoryBanks(Protocol): - @webmethod(route="/memory_banks/list", method="GET") + @webmethod(route="/memory_banks_router/list", method="GET") async def list_memory_banks(self) -> List[MemoryBankSpec]: ... - @webmethod(route="/memory_banks/get", method="GET") + @webmethod(route="/memory_banks_router/get", method="GET") async def get_memory_bank( self, bank_type: MemoryBankType ) -> Optional[MemoryBankSpec]: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index a3ff86cdf..3a60a057d 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -63,13 +63,6 @@ class RoutableProviderConfig(GenericProviderConfig): routing_key: str -class RoutingTableConfig(BaseModel): - entries: List[RoutableProviderConfig] = Field(...) - keys: Optional[List[str]] = Field( - default=None, - ) - - # Example: /inference, /safety @json_schema_type class AutoRoutedProviderSpec(ProviderSpec): @@ -275,7 +268,7 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_ Provider configurations for each of the APIs provided by this package. """, ) - routing_tables: Dict[str, RoutingTableConfig] = Field( + routing_table: Dict[str, List[RoutableProviderConfig]] = Field( default_factory=dict, description=""" diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index e8b8938b0..363c863aa 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -12,7 +12,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403 async def get_routing_table_impl( api: Api, inner_impls: List[Tuple[str, Any]], - routing_table_config: RoutingTableConfig, + routing_table_config: Dict[str, List[RoutableProviderConfig]], _deps, ) -> Any: from .routing_tables import ( diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index c9a536aa0..ba32e5986 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -162,6 +162,7 @@ class SafetyRouter(Safety): messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: + print(f"Running shield {shield_type}") return await self.routing_table.get_provider_impl(shield_type).run_shield( shield_type=shield_type, messages=messages, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index cd014d28d..fcd4d2b2b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable): def __init__( self, inner_impls: List[Tuple[str, Any]], - routing_table_config: RoutingTableConfig, + routing_table_config: Dict[str, List[RoutableProviderConfig]], ) -> None: self.providers = {k: v for k, v in inner_impls} self.routing_keys = list(self.providers.keys()) @@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable): return self.routing_keys def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == routing_key: return entry return None @@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelServingSpec]: specs = [] - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: model_id = entry.routing_key specs.append( ModelServingSpec( @@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return specs async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == core_model_id: return ModelServingSpec( llama_model=resolve_model(core_model_id), @@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldSpec]: specs = [] - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: specs.append( ShieldSpec( shield_type=entry.routing_key, @@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return specs async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == shield_type: return ShieldSpec( shield_type=entry.routing_key, @@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankSpec]: specs = [] - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: specs.append( MemoryBankSpec( bank_type=entry.routing_key, @@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): return specs async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]: - for entry in self.routing_table_config.entries: + for entry in self.routing_table_config: if entry.routing_key == bank_type: return MemoryBankSpec( bank_type=entry.routing_key, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 18433596f..4f9ef7ea6 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from fastapi.routing import APIRoute -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import ( SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.distribution import ( @@ -315,7 +315,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An configs[api] = config apis_to_serve = run_config.apis_to_serve or set( - list(specs.keys()) + list(run_config.routing_tables.keys()) + list(specs.keys()) + list(run_config.routing_table.keys()) ) print("apis_to_serve", apis_to_serve) for info in builtin_automatically_routed_apis(): @@ -331,15 +331,16 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An if info.router_api.value not in apis_to_serve: continue - if source_api.value not in run_config.routing_tables: + print("router_api", info.router_api) + if info.router_api.value not in run_config.routing_table: raise ValueError(f"Routing table for `{source_api.value}` is not provided?") - routing_table = run_config.routing_tables[source_api.value] + routing_table = run_config.routing_table[info.router_api.value] providers = all_providers[info.router_api] inner_specs = [] - for rt_entry in routing_table.entries: + for rt_entry in routing_table: if rt_entry.provider_id not in providers: raise ValueError( f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/distribution/utils/dynamic.py index 6d9c57dfd..7c2ac2e6a 100644 --- a/llama_stack/distribution/utils/dynamic.py +++ b/llama_stack/distribution/utils/dynamic.py @@ -43,12 +43,12 @@ async def instantiate_provider( elif isinstance(provider_spec, RoutingTableProviderSpec): method = "get_routing_table_impl" - assert isinstance(provider_config, RoutingTableConfig) + assert isinstance(provider_config, List) routing_table = provider_config inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} inner_impls = [] - for routing_entry in routing_table.entries: + for routing_entry in routing_table: impl = await instantiate_provider( inner_specs[routing_entry.provider_id], deps, diff --git a/tests/examples/router-local-run.yaml b/tests/examples/router-local-run.yaml index 807dcafec..df4c453b2 100644 --- a/tests/examples/router-local-run.yaml +++ b/tests/examples/router-local-run.yaml @@ -20,40 +20,29 @@ api_providers: namespace: null type: sqlite db_path: /home/xiyan/.llama/runtime/kvstore.db -routing_tables: - models: - entries: - - routing_key: Meta-Llama3.1-8B-Instruct - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B-Instruct - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - - routing_key: Meta-Llama3.1-8B - provider_id: meta-reference - config: - model: Meta-Llama3.1-8B - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - memory_banks: - entries: - - routing_key: vector - provider_id: meta-reference - config: {} - shields: - entries: - - routing_key: llama_guard_shield - provider_id: meta-reference - config: - model: Llama-Guard-3-8B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - - routing_key: prompt_guard_shield - provider_id: meta-reference - config: - model: Prompt-Guard-86M +routing_table: + inference: + - routing_key: Meta-Llama3.1-8B-Instruct + provider_id: meta-reference + config: + model: Meta-Llama3.1-8B-Instruct + quantization: null + torch_seed: null + max_seq_len: 4096 + max_batch_size: 1 + memory: + - routing_key: vector + provider_id: meta-reference + config: {} + safety: + - routing_key: llama_guard + provider_id: meta-reference + config: + model: Llama-Guard-3-8B + excluded_categories: [] + disable_input_check: false + disable_output_check: false + - routing_key: prompt_guard + provider_id: meta-reference + config: + model: Prompt-Guard-86M