Test Plan: First, start a TGI container with the `meta-llama/Llama-Guard-3-8B` model serving on port 5099. See https://github.com/meta-llama/llama-stack/pull/53 and its description for how. Then run llama-stack with the following run config:

```yaml
image_name: safety
docker_image: null
conda_env: safety
apis_to_serve:
- models
- inference
- shields
- safety
api_providers:
  inference:
    providers:
    - remote::tgi
  safety:
    providers:
    - meta-reference
  telemetry:
    provider_id: meta-reference
    config: {}
routing_table:
  inference:
  - provider_id: remote::tgi
    config:
      url: http://localhost:5099
      api_token: null
      hf_endpoint_name: null
    routing_key: Llama-Guard-3-8B
  safety:
  - provider_id: meta-reference
    config:
      llama_guard_shield:
        model: Llama-Guard-3-8B
        excluded_categories: []
        disable_input_check: false
        disable_output_check: false
      prompt_guard_shield: null
    routing_key: llama_guard
```

Now simply run `python -m llama_stack.apis.safety.client localhost <port>` and check that the llama_guard shield calls run correctly. (The injection_shield calls fail as expected since we have not set up a router for them.)
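As a quick sanity check on the config itself, the sketch below (a hypothetical helper, not part of llama-stack) loads the run config with PyYAML and prints the routing keys registered per API; `run.yaml` is a placeholder for wherever the config above is saved. It shows why only the `llama_guard` shield has a route in this setup.

```python
# Hypothetical helper: inspect which routing keys the run config registers
# for each API. Requires PyYAML; "run.yaml" is a placeholder path.
import yaml

with open("run.yaml") as f:
    run_config = yaml.safe_load(f)

# routing_table maps each API to a list of provider entries, each carrying
# a routing_key (see the routing table implementations below).
for api, entries in run_config["routing_table"].items():
    keys = [entry["routing_key"] for entry in entries]
    print(f"{api}: {keys}")

# Expected output for the config above:
#   inference: ['Llama-Guard-3-8B']
#   safety: ['llama_guard']
# There is no entry for an injection/prompt-guard shield, so those calls
# have no provider to route to.
```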
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, List, Optional, Tuple

from llama_models.sku_list import resolve_model

from llama_models.llama3.api.datatypes import *  # noqa: F403

from llama_stack.apis.models import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403
from llama_stack.apis.memory_banks import *  # noqa: F403

from llama_stack.distribution.datatypes import *  # noqa: F403


class CommonRoutingTableImpl(RoutingTable):
    def __init__(
        self,
        inner_impls: List[Tuple[str, Any]],
        routing_table_config: List[RoutableProviderConfig],
    ) -> None:
        self.providers = {k: v for k, v in inner_impls}
        self.routing_keys = list(self.providers.keys())
        self.routing_table_config = routing_table_config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        for p in self.providers.values():
            await p.shutdown()

    def get_provider_impl(self, routing_key: str) -> Any:
        if routing_key not in self.providers:
            raise ValueError(f"Could not find provider for {routing_key}")
        return self.providers[routing_key]

    def get_routing_keys(self) -> List[str]:
        return self.routing_keys

    def get_provider_config(self, routing_key: str) -> Optional[GenericProviderConfig]:
        for entry in self.routing_table_config:
            if entry.routing_key == routing_key:
                return entry
        return None


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    async def list_models(self) -> List[ModelServingSpec]:
        specs = []
        for entry in self.routing_table_config:
            model_id = entry.routing_key
            specs.append(
                ModelServingSpec(
                    llama_model=resolve_model(model_id),
                    provider_config=entry,
                )
            )
        return specs

    async def get_model(self, core_model_id: str) -> Optional[ModelServingSpec]:
        for entry in self.routing_table_config:
            if entry.routing_key == core_model_id:
                return ModelServingSpec(
                    llama_model=resolve_model(core_model_id),
                    provider_config=entry,
                )
        return None


class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
    async def list_shields(self) -> List[ShieldSpec]:
        specs = []
        for entry in self.routing_table_config:
            specs.append(
                ShieldSpec(
                    shield_type=entry.routing_key,
                    provider_config=entry,
                )
            )
        return specs

    async def get_shield(self, shield_type: str) -> Optional[ShieldSpec]:
        for entry in self.routing_table_config:
            if entry.routing_key == shield_type:
                return ShieldSpec(
                    shield_type=entry.routing_key,
                    provider_config=entry,
                )
        return None


class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
    async def list_available_memory_banks(self) -> List[MemoryBankSpec]:
        specs = []
        for entry in self.routing_table_config:
            specs.append(
                MemoryBankSpec(
                    bank_type=entry.routing_key,
                    provider_config=entry,
                )
            )
        return specs

    async def get_serving_memory_bank(self, bank_type: str) -> Optional[MemoryBankSpec]:
        for entry in self.routing_table_config:
            if entry.routing_key == bank_type:
                return MemoryBankSpec(
                    bank_type=entry.routing_key,
                    provider_config=entry,
                )
        return None
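For reference, here is a minimal sketch of how a routing table behaves once constructed. It is not part of the file above: the import path is assumed to be `llama_stack.distribution.routers.routing_tables`, the provider stub is hypothetical, and `SimpleNamespace` stands in for real `RoutableProviderConfig` entries; in the real stack, `inner_impls` and the config entries come from the resolved providers and the run config's `routing_table` section.

```python
# Minimal sketch under the assumptions stated above.
from types import SimpleNamespace

from llama_stack.distribution.routers.routing_tables import (  # path assumed
    CommonRoutingTableImpl,
)


class FakeShieldProvider:
    """Hypothetical provider stub; only the routing-by-key behavior matters here."""

    async def shutdown(self) -> None:
        pass


# One (routing_key, provider impl) pair, mirroring inner_impls above.
inner_impls = [("llama_guard", FakeShieldProvider())]

# Stand-in config entries; only routing_key is read by the lookups in this file.
routing_table_config = [
    SimpleNamespace(routing_key="llama_guard", provider_id="meta-reference"),
]

table = CommonRoutingTableImpl(inner_impls, routing_table_config)

print(table.get_routing_keys())                        # ['llama_guard']
print(table.get_provider_impl("llama_guard"))          # the FakeShieldProvider
print(table.get_provider_config("injection_shield"))   # None: no entry configured

try:
    table.get_provider_impl("injection_shield")
except ValueError as e:
    print(e)  # "Could not find provider for injection_shield"
```

This mirrors the behavior exercised by the test plan: the `llama_guard` routing key resolves to its provider, while a shield with no routing table entry has nothing to route to.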