mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
issue w/ safety
This commit is contained in:
parent
e0ad4fb99c
commit
4586692dee
8 changed files with 50 additions and 66 deletions
|
@ -7,11 +7,11 @@
|
||||||
from typing import List, Optional, Protocol
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from llama_stack.apis.memory import MemoryBankType
|
from llama_stack.apis.memory import MemoryBankType
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import GenericProviderConfig
|
from llama_stack.distribution.datatypes import GenericProviderConfig
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -23,10 +23,10 @@ class MemoryBankSpec(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class MemoryBanks(Protocol):
|
class MemoryBanks(Protocol):
|
||||||
@webmethod(route="/memory_banks/list", method="GET")
|
@webmethod(route="/memory_banks_router/list", method="GET")
|
||||||
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
|
async def list_memory_banks(self) -> List[MemoryBankSpec]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get", method="GET")
|
@webmethod(route="/memory_banks_router/get", method="GET")
|
||||||
async def get_memory_bank(
|
async def get_memory_bank(
|
||||||
self, bank_type: MemoryBankType
|
self, bank_type: MemoryBankType
|
||||||
) -> Optional[MemoryBankSpec]: ...
|
) -> Optional[MemoryBankSpec]: ...
|
||||||
|
|
|
@ -63,13 +63,6 @@ class RoutableProviderConfig(GenericProviderConfig):
|
||||||
routing_key: str
|
routing_key: str
|
||||||
|
|
||||||
|
|
||||||
class RoutingTableConfig(BaseModel):
|
|
||||||
entries: List[RoutableProviderConfig] = Field(...)
|
|
||||||
keys: Optional[List[str]] = Field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Example: /inference, /safety
|
# Example: /inference, /safety
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AutoRoutedProviderSpec(ProviderSpec):
|
class AutoRoutedProviderSpec(ProviderSpec):
|
||||||
|
@ -275,7 +268,7 @@ The list of APIs to serve. If not specified, all APIs specified in the provider_
|
||||||
Provider configurations for each of the APIs provided by this package.
|
Provider configurations for each of the APIs provided by this package.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
routing_tables: Dict[str, RoutingTableConfig] = Field(
|
routing_table: Dict[str, List[RoutableProviderConfig]] = Field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
description="""
|
description="""
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
async def get_routing_table_impl(
|
async def get_routing_table_impl(
|
||||||
api: Api,
|
api: Api,
|
||||||
inner_impls: List[Tuple[str, Any]],
|
inner_impls: List[Tuple[str, Any]],
|
||||||
routing_table_config: RoutingTableConfig,
|
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||||
_deps,
|
_deps,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
from .routing_tables import (
|
from .routing_tables import (
|
||||||
|
|
|
@ -162,6 +162,7 @@ class SafetyRouter(Safety):
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: Dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
|
print(f"Running shield {shield_type}")
|
||||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||||
shield_type=shield_type,
|
shield_type=shield_type,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -20,7 +20,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
inner_impls: List[Tuple[str, Any]],
|
inner_impls: List[Tuple[str, Any]],
|
||||||
routing_table_config: RoutingTableConfig,
|
routing_table_config: Dict[str, List[RoutableProviderConfig]],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.providers = {k: v for k, v in inner_impls}
|
self.providers = {k: v for k, v in inner_impls}
|
||||||
self.routing_keys = list(self.providers.keys())
|
self.routing_keys = list(self.providers.keys())
|
||||||
|
@ -40,7 +40,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
return self.routing_keys
|
return self.routing_keys
|
||||||
|
|
||||||
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
|
||||||
for entry in self.routing_table_config.entries:
|
for entry in self.routing_table_config:
|
||||||
if entry.routing_key == routing_key:
|
if entry.routing_key == routing_key:
|
||||||
return entry
|
return entry
|
||||||
return None
|
return None
|
||||||
|
@ -50,7 +50,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
||||||
async def list_models(self) -> List[ModelServingSpec]:
|
async def list_models(self) -> List[ModelServingSpec]:
|
||||||
specs = []
|
specs = []
|
||||||
for entry in self.routing_table_config.entries:
|
for entry in self.routing_table_config:
|
||||||
model_id = entry.routing_key
|
model_id = entry.routing_key
|
||||||
specs.append(
|
specs.append(
|
||||||
ModelServingSpec(
|
ModelServingSpec(
|
||||||
|
@ -61,7 +61,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
return specs
|
return specs
|
||||||
|
|
||||||
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
|
||||||
for entry in self.routing_table_config.entries:
|
for entry in self.routing_table_config:
|
||||||
if entry.routing_key == core_model_id:
|
if entry.routing_key == core_model_id:
|
||||||
return ModelServingSpec(
|
return ModelServingSpec(
|
||||||
llama_model=resolve_model(core_model_id),
|
llama_model=resolve_model(core_model_id),
|
||||||
|
@ -74,7 +74,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
|
||||||
async def list_shields(self) -> List[ShieldSpec]:
|
async def list_shields(self) -> List[ShieldSpec]:
|
||||||
specs = []
|
specs = []
|
||||||
for entry in self.routing_table_config.entries:
|
for entry in self.routing_table_config:
|
||||||
specs.append(
|
specs.append(
|
||||||
ShieldSpec(
|
ShieldSpec(
|
||||||
shield_type=entry.routing_key,
|
shield_type=entry.routing_key,
|
||||||
|
@ -84,7 +84,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
return specs
|
return specs
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
|
||||||
for entry in self.routing_table_config.entries:
|
for entry in self.routing_table_config:
|
||||||
if entry.routing_key == shield_type:
|
if entry.routing_key == shield_type:
|
||||||
return ShieldSpec(
|
return ShieldSpec(
|
||||||
shield_type=entry.routing_key,
|
shield_type=entry.routing_key,
|
||||||
|
@ -97,7 +97,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[MemoryBankSpec]:
|
async def list_memory_banks(self) -> List[MemoryBankSpec]:
|
||||||
specs = []
|
specs = []
|
||||||
for entry in self.routing_table_config.entries:
|
for entry in self.routing_table_config:
|
||||||
specs.append(
|
specs.append(
|
||||||
MemoryBankSpec(
|
MemoryBankSpec(
|
||||||
bank_type=entry.routing_key,
|
bank_type=entry.routing_key,
|
||||||
|
@ -107,7 +107,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
return specs
|
return specs
|
||||||
|
|
||||||
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
async def get_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
|
||||||
for entry in self.routing_table_config.entries:
|
for entry in self.routing_table_config:
|
||||||
if entry.routing_key == bank_type:
|
if entry.routing_key == bank_type:
|
||||||
return MemoryBankSpec(
|
return MemoryBankSpec(
|
||||||
bank_type=entry.routing_key,
|
bank_type=entry.routing_key,
|
||||||
|
|
|
@ -35,9 +35,6 @@ from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from fastapi.routing import APIRoute
|
from fastapi.routing import APIRoute
|
||||||
from pydantic import BaseModel, ValidationError
|
|
||||||
from termcolor import cprint
|
|
||||||
from typing_extensions import Annotated
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
end_trace,
|
end_trace,
|
||||||
|
@ -45,6 +42,9 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
SpanStatus,
|
SpanStatus,
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from termcolor import cprint
|
||||||
|
from typing_extensions import Annotated
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.distribution import (
|
from llama_stack.distribution.distribution import (
|
||||||
|
@ -315,7 +315,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
configs[api] = config
|
configs[api] = config
|
||||||
|
|
||||||
apis_to_serve = run_config.apis_to_serve or set(
|
apis_to_serve = run_config.apis_to_serve or set(
|
||||||
list(specs.keys()) + list(run_config.routing_tables.keys())
|
list(specs.keys()) + list(run_config.routing_table.keys())
|
||||||
)
|
)
|
||||||
print("apis_to_serve", apis_to_serve)
|
print("apis_to_serve", apis_to_serve)
|
||||||
for info in builtin_automatically_routed_apis():
|
for info in builtin_automatically_routed_apis():
|
||||||
|
@ -331,15 +331,16 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
||||||
if info.router_api.value not in apis_to_serve:
|
if info.router_api.value not in apis_to_serve:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if source_api.value not in run_config.routing_tables:
|
print("router_api", info.router_api)
|
||||||
|
if info.router_api.value not in run_config.routing_table:
|
||||||
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
raise ValueError(f"Routing table for `{source_api.value}` is not provided?")
|
||||||
|
|
||||||
routing_table = run_config.routing_tables[source_api.value]
|
routing_table = run_config.routing_table[info.router_api.value]
|
||||||
|
|
||||||
providers = all_providers[info.router_api]
|
providers = all_providers[info.router_api]
|
||||||
|
|
||||||
inner_specs = []
|
inner_specs = []
|
||||||
for rt_entry in routing_table.entries:
|
for rt_entry in routing_table:
|
||||||
if rt_entry.provider_id not in providers:
|
if rt_entry.provider_id not in providers:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
|
||||||
|
|
|
@ -43,12 +43,12 @@ async def instantiate_provider(
|
||||||
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
elif isinstance(provider_spec, RoutingTableProviderSpec):
|
||||||
method = "get_routing_table_impl"
|
method = "get_routing_table_impl"
|
||||||
|
|
||||||
assert isinstance(provider_config, RoutingTableConfig)
|
assert isinstance(provider_config, List)
|
||||||
routing_table = provider_config
|
routing_table = provider_config
|
||||||
|
|
||||||
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
inner_specs = {x.provider_id: x for x in provider_spec.inner_specs}
|
||||||
inner_impls = []
|
inner_impls = []
|
||||||
for routing_entry in routing_table.entries:
|
for routing_entry in routing_table:
|
||||||
impl = await instantiate_provider(
|
impl = await instantiate_provider(
|
||||||
inner_specs[routing_entry.provider_id],
|
inner_specs[routing_entry.provider_id],
|
||||||
deps,
|
deps,
|
||||||
|
|
|
@ -20,40 +20,29 @@ api_providers:
|
||||||
namespace: null
|
namespace: null
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: /home/xiyan/.llama/runtime/kvstore.db
|
db_path: /home/xiyan/.llama/runtime/kvstore.db
|
||||||
routing_tables:
|
routing_table:
|
||||||
models:
|
inference:
|
||||||
entries:
|
- routing_key: Meta-Llama3.1-8B-Instruct
|
||||||
- routing_key: Meta-Llama3.1-8B-Instruct
|
provider_id: meta-reference
|
||||||
provider_id: meta-reference
|
config:
|
||||||
config:
|
model: Meta-Llama3.1-8B-Instruct
|
||||||
model: Meta-Llama3.1-8B-Instruct
|
quantization: null
|
||||||
quantization: null
|
torch_seed: null
|
||||||
torch_seed: null
|
max_seq_len: 4096
|
||||||
max_seq_len: 4096
|
max_batch_size: 1
|
||||||
max_batch_size: 1
|
memory:
|
||||||
- routing_key: Meta-Llama3.1-8B
|
- routing_key: vector
|
||||||
provider_id: meta-reference
|
provider_id: meta-reference
|
||||||
config:
|
config: {}
|
||||||
model: Meta-Llama3.1-8B
|
safety:
|
||||||
quantization: null
|
- routing_key: llama_guard
|
||||||
torch_seed: null
|
provider_id: meta-reference
|
||||||
max_seq_len: 4096
|
config:
|
||||||
max_batch_size: 1
|
model: Llama-Guard-3-8B
|
||||||
memory_banks:
|
excluded_categories: []
|
||||||
entries:
|
disable_input_check: false
|
||||||
- routing_key: vector
|
disable_output_check: false
|
||||||
provider_id: meta-reference
|
- routing_key: prompt_guard
|
||||||
config: {}
|
provider_id: meta-reference
|
||||||
shields:
|
config:
|
||||||
entries:
|
model: Prompt-Guard-86M
|
||||||
- routing_key: llama_guard_shield
|
|
||||||
provider_id: meta-reference
|
|
||||||
config:
|
|
||||||
model: Llama-Guard-3-8B
|
|
||||||
excluded_categories: []
|
|
||||||
disable_input_check: false
|
|
||||||
disable_output_check: false
|
|
||||||
- routing_key: prompt_guard_shield
|
|
||||||
provider_id: meta-reference
|
|
||||||
config:
|
|
||||||
model: Prompt-Guard-86M
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue