forked from phoenix-oss/llama-stack-mirror
Use inference APIs for running llama guard
Test Plan: First, start a TGI container with `meta-llama/Llama-Guard-3-8B` model serving on port 5099. See https://github.com/meta-llama/llama-stack/pull/53 and its description for how. Then run llama-stack with the following run config: ``` image_name: safety docker_image: null conda_env: safety apis_to_serve: - models - inference - shields - safety api_providers: inference: providers: - remote::tgi safety: providers: - meta-reference telemetry: provider_id: meta-reference config: {} routing_table: inference: - provider_id: remote::tgi config: url: http://localhost:5099 api_token: null hf_endpoint_name: null routing_key: Llama-Guard-3-8B safety: - provider_id: meta-reference config: llama_guard_shield: model: Llama-Guard-3-8B excluded_categories: [] disable_input_check: false disable_output_check: false prompt_guard_shield: null routing_key: llama_guard ``` Now simply run `python -m llama_stack.apis.safety.client localhost <port>` and check that the llama_guard shield calls run correctly. (The injection_shield calls fail as expected since we have not set up a router for them.)
This commit is contained in:
parent
c4534217c8
commit
0d2eb3bd25
9 changed files with 56 additions and 81 deletions
|
@ -33,8 +33,10 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.providers.values():
|
||||
await p.shutdown()
|
||||
|
||||
def get_provider_impl(self, routing_key: str) -> Optional[Any]:
|
||||
return self.providers.get(routing_key)
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.providers:
|
||||
raise ValueError(f"Could not find provider for {routing_key}")
|
||||
return self.providers[routing_key]
|
||||
|
||||
def get_routing_keys(self) -> List[str]:
|
||||
return self.routing_keys
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue