issue w/ safety

This commit is contained in:
Xi Yan 2024-09-22 23:15:34 -07:00
parent e0ad4fb99c
commit 4586692dee
8 changed files with 50 additions and 66 deletions

View file

@@ -7,11 +7,11 @@
from typing import List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.memory import MemoryBankType
from llama_stack.distribution.datatypes import GenericProviderConfig
from pydantic import BaseModel, Field
@json_schema_type
@@ -23,10 +23,10 @@ class MemoryBankSpec(BaseModel):
class MemoryBanks(Protocol):
@webmethod(route="/memory_banks/list", method="GET")
@webmethod(route="/memory_banks_router/list", method="GET")
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
@webmethod(route="/memory_banks/get", method="GET")
@webmethod(route="/memory_banks_router/get", method="GET")
async def get_memory_bank(
self, bank_type: MemoryBankType
) -> Optional[MemoryBankSpec]: ...

View file

@@ -63,13 +63,6 @@ class RoutableProviderConfig(GenericProviderConfig):
routing_key: str
class RoutingTableConfig(BaseModel):
entries: List[RoutableProviderConfig] = Field(...)
keys: Optional[List[str]] = Field(
default=None,
)
# Example: /inference, /safety
@json_schema_type
class AutoRoutedProviderSpec(ProviderSpec):
@@ -275,7 +268,7 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_
Provider configurations for each of the APIs provided by this package.
""",
)
routing_tables: Dict[str, RoutingTableConfig] = Field(
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
default_factory=dict,
description="""

View file

@@ -12,7 +12,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
async def get_routing_table_impl(
api: Api,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
routing_table_config: Dict[str, List[RoutableProviderConfig]],
_deps,
) -> Any:
from .routing_tables import (

View file

@@ -162,6 +162,7 @@ class SafetyRouter(Safety):
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
print(f"Running shield {shield_type}")
return await self.routing_table.get_provider_impl(shield_type).run_shield(
shield_type=shield_type,
messages=messages,

View file

@@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
inner_impls: List[Tuple[str, Any]],
routing_table_config: RoutingTableConfig,
routing_table_config: Dict[str, List[RoutableProviderConfig]],
) -> None:
self.providers = {k: v for k, v in inner_impls}
self.routing_keys = list(self.providers.keys())
@@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable):
return self.routing_keys
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == routing_key:
return entry
return None
@@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> List[ModelServingSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
model_id = entry.routing_key
specs.append(
ModelServingSpec(
@@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
return specs
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == core_model_id:
return ModelServingSpec(
llama_model=resolve_model(core_model_id),
@@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
ShieldSpec(
shield_type=entry.routing_key,
@@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
return specs
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == shield_type:
return ShieldSpec(
shield_type=entry.routing_key,
@@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
async def list_memory_banks(self) -> List[MemoryBankSpec]:
specs = []
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
specs.append(
MemoryBankSpec(
bank_type=entry.routing_key,
@@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
return specs
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
for entry in self.routing_table_config.entries:
for entry in self.routing_table_config:
if entry.routing_key == bank_type:
return MemoryBankSpec(
bank_type=entry.routing_key,

View file

@@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
SpanStatus,
start_trace,
)
from pydantic import BaseModel, ValidationError
from termcolor import cprint
from typing_extensions import Annotated
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.distribution import (
@@ -315,7 +315,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
configs[api] = config
apis_to_serve = run_config.apis_to_serve or set(
list(specs.keys()) + list(run_config.routing_tables.keys())
list(specs.keys()) + list(run_config.routing_table.keys())
)
print("apis_to_serve", apis_to_serve)
for info in builtin_automatically_routed_apis():
@@ -331,15 +331,16 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
if info.router_api.value not in apis_to_serve:
continue
if source_api.value not in run_config.routing_tables:
print("router_api", info.router_api)
if info.router_api.value not in run_config.routing_table:
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
routing_table = run_config.routing_tables[source_api.value]
routing_table = run_config.routing_table[info.router_api.value]
providers = all_providers[info.router_api]
inner_specs = []
for rt_entry in routing_table.entries:
for rt_entry in routing_table:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"

View file

@@ -43,12 +43,12 @@ async def instantiate_provider(
elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl"
assert isinstance(provider_config, RoutingTableConfig)
assert isinstance(provider_config, List)
routing_table = provider_config
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
inner_impls = []
for routing_entry in routing_table.entries:
for routing_entry in routing_table:
impl = await instantiate_provider(
inner_specs[routing_entry.provider_id],
deps,

View file

@@ -20,40 +20,29 @@ api_providers:
namespace: null
type: sqlite
db_path: /home/xiyan/.llama/runtime/kvstore.db
routing_tables:
models:
entries:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
- routing_key: Meta-Llama3.1-8B
provider_id: meta-reference
config:
model: Meta-Llama3.1-8B
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
memory_banks:
entries:
- routing_key: vector
provider_id: meta-reference
config: {}
shields:
entries:
- routing_key: llama_guard_shield
provider_id: meta-reference
config:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
- routing_key: prompt_guard_shield
provider_id: meta-reference
config:
model: Prompt-Guard-86M
routing_table:
inference:
- routing_key: Meta-Llama3.1-8B-Instruct
provider_id: meta-reference
config:
model: Meta-Llama3.1-8B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
memory:
- routing_key: vector
provider_id: meta-reference
config: {}
safety:
- routing_key: llama_guard
provider_id: meta-reference
config:
model: Llama-Guard-3-8B
excluded_categories: []
disable_input_check: false
disable_output_check: false
- routing_key: prompt_guard
provider_id: meta-reference
config:
model: Prompt-Guard-86M