Issue with safety

This commit is contained in:
Xi Yan 2024-09-22 23:15:34 -07:00
parent e0ad4fb99c
commit 4586692dee
8 changed files with 50 additions and 66 deletions

View file

@ -7,11 +7,11 @@
from typing import List, Optional, Protocol from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field
@json_schema_type @json_schema_type
@ -23,10 +23,10 @@ class MemoryBankSpec(BaseModel):
class MemoryBanks(Protocol): class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET") @webmethod(route="/memory_banks_router/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBankSpec]: ... async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
@webmethod(route="/memory_banks/get", method="GET") @webmethod(route="/memory_banks_router/get", method="GET")
async def get_memory_bank( async def get_memory_bank(
self, bank_type: MemoryBankType self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ... ) -> Optional[MemoryBankSpec]: ...

View file

@ -63,13 +63,6 @@ class RoutableProviderConfig(GenericProviderConfig):
routing_key: str routing_key: str
class RoutingTableConfig(BaseModel):
entries: List[RoutableProviderConfig] = Field(...)
keys: Optional[List[str]] = Field(
default=None,
)
# Example: /inference, /safety # Example: /inference, /safety
@json_schema_type @json_schema_type
class AutoRoutedProviderSpec(ProviderSpec): class AutoRoutedProviderSpec(ProviderSpec):
@ -275,7 +268,7 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_
Provider configurations for each of the APIs provided by this package. Provider configurations for each of the APIs provided by this package.
""", """,
) )
routing_tables: Dict[str, RoutingTableConfig] = Field( routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
default_factory=dict, default_factory=dict,
description=""" description="""

View file

@ -12,7 +12,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
async def get_routing_table_impl( async def get_routing_table_impl(
api: Api, api: Api,
inner_impls: List[Tuple[str, Any]], inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig, routing_table_config: Dict[str, List[RoutableProviderConfig]],
_deps, _deps,
) -> Any: ) -> Any:
from .routing_tables import ( from .routing_tables import (

View file

@ -162,6 +162,7 @@ class SafetyRouter(Safety):
messages: List[Message], messages: List[Message],
params: Dict[str, Any] = None, params: Dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
print(f"Running shield {shield_type}")
return await self.routing_table.get_provider_impl(shield_type).run_shield( return await self.routing_table.get_provider_impl(shield_type).run_shield(
shield_type=shield_type, shield_type=shield_type,
messages=messages, messages=messages,

View file

@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__( def __init__(
self, self,
inner_impls: List[Tuple[str, Any]], inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig, routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None: ) -> None:
self.providers = {k: v for k, v in inner_impls} self.providers = {k: v for k, v in inner_impls}
self.routing_keys = list(self.providers.keys()) self.routing_keys = list(self.providers.keys())
@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable):
return self.routing_keys return self.routing_keys
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]: def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config.entries: for entry in self.routing_table_config:
if entry.routing_key == routing_key: if entry.routing_key == routing_key:
return entry return entry
return None return None
@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelServingSpec]: async def list_models(self) -> List[ModelServingSpec]:
specs = [] specs = []
for entry in self.routing_table_config.entries: for entry in self.routing_table_config:
model_id = entry.routing_key model_id = entry.routing_key
specs.append( specs.append(
ModelServingSpec( ModelServingSpec(
@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return specs return specs
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]: async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config.entries: for entry in self.routing_table_config:
if entry.routing_key == core_model_id: if entry.routing_key == core_model_id:
return ModelServingSpec( return ModelServingSpec(
llama_model=resolve_model(core_model_id), llama_model=resolve_model(core_model_id),
@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]: async def list_shields(self) -> List[ShieldSpec]:
specs = [] specs = []
for entry in self.routing_table_config.entries: for entry in self.routing_table_config:
specs.append( specs.append(
ShieldSpec( ShieldSpec(
shield_type=entry.routing_key, shield_type=entry.routing_key,
@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return specs return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]: async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config.entries: for entry in self.routing_table_config:
if entry.routing_key == shield_type: if entry.routing_key == shield_type:
return ShieldSpec( return ShieldSpec(
shield_type=entry.routing_key, shield_type=entry.routing_key,
@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankSpec]: async def list_memory_banks(self) -> List[MemoryBankSpec]:
specs = [] specs = []
for entry in self.routing_table_config.entries: for entry in self.routing_table_config:
specs.append( specs.append(
MemoryBankSpec( MemoryBankSpec(
bank_type=entry.routing_key, bank_type=entry.routing_key,
@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return specs return specs
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]: async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config.entries: for entry in self.routing_table_config:
if entry.routing_key == bank_type: if entry.routing_key == bank_type:
return MemoryBankSpec( return MemoryBankSpec(
bank_type=entry.routing_key, bank_type=entry.routing_key,

View file

@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.providers.utils.telemetry.tracing import (
end_trace, end_trace,
@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus, SpanStatus,
start_trace, start_trace,
) )
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import ( from llama_stack.distribution.distribution import (
@ -315,7 +315,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
configs[api] = config configs[api] = config
apis_to_serve = run_config.apis_to_serve or set( apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_tables.keys()) list(specs.keys()) + list(run_config.routing_table.keys())
) )
print("apis_to_serve", apis_to_serve) print("apis_to_serve", apis_to_serve)
for info in builtin_automatically_routed_apis(): for info in builtin_automatically_routed_apis():
@ -331,15 +331,16 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if info.router_api.value not in apis_to_serve: if info.router_api.value not in apis_to_serve:
continue continue
if source_api.value not in run_config.routing_tables: print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?") raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_tables[source_api.value] routing_table = run_config.routing_table[info.router_api.value]
providers = all_providers[info.router_api] providers = all_providers[info.router_api]
inner_specs = [] inner_specs = []
for rt_entry in routing_table.entries: for rt_entry in routing_table:
if rt_entry.provider_id not in providers: if rt_entry.provider_id not in providers:
raise ValueError( raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`" f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"

View file

@ -43,12 +43,12 @@ async def instantiate_provider(
elif isinstance(provider_spec, RoutingTableProviderSpec): elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl" method = "get_routing_table_impl"
assert isinstance(provider_config, RoutingTableConfig) assert isinstance(provider_config, List)
routing_table = provider_config routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs} inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = [] inner_impls = []
for routing_entry in routing_table.entries: for routing_entry in routing_table:
impl = await instantiate_provider( impl = await instantiate_provider(
inner_specs[routing_entry.provider_id], inner_specs[routing_entry.provider_id],
deps, deps,

View file

@ -20,40 +20,29 @@ api_providers:
namespace: null namespace: null
type: sqlite type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db db_path: /home/xiyan/.llama/runtime/kvstore.db
routing_tables: routing_table:
models: inference:
entries: - routing_key: Meta-Llama3.1-8B-Instruct
- routing_key: Meta-Llama3.1-8B-Instruct provider_id: meta-reference
provider_id: meta-reference config:
config: model: Meta-Llama3.1-8B-Instruct
model: Meta-Llama3.1-8B-Instruct quantization: null
quantization: null torch_seed: null
torch_seed: null max_seq_len: 4096
max_seq_len: 4096 max_batch_size: 1
max_batch_size: 1 memory:
- routing_key: Meta-Llama3.1-8B - routing_key: vector
provider_id: meta-reference provider_id: meta-reference
config: config: {}
model: Meta-Llama3.1-8B safety:
quantization: null - routing_key: llama_guard
torch_seed: null provider_id: meta-reference
max_seq_len: 4096 config:
max_batch_size: 1 model: Llama-Guard-3-8B
memory_banks: excluded_categories: []
entries: disable_input_check: false
- routing_key: vector disable_output_check: false
provider_id: meta-reference - routing_key: prompt_guard
config: {} provider_id: meta-reference
shields: config:
entries: model: Prompt-Guard-86M
- routing_key: llama_guard_shield
provider_id: meta-reference
config:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
- routing_key: prompt_guard_shield
provider_id: meta-reference
config:
model: Prompt-Guard-86M