Use inference APIs for running llama guard

Test Plan:

First, start a TGI container with `meta-llama/Llama-Guard-3-8B` model
serving on port 5099. See https://github.com/meta-llama/llama-stack/pull/53 and its
description for how.

Then run llama-stack with the following run config:

```
image_name: safety
docker_image: null
conda_env: safety
apis_to_serve:
- models
- inference
- shields
- safety
api_providers:
  inference:
    providers:
    - remote::tgi
  safety:
    providers:
    - meta-reference
  telemetry:
    provider_id: meta-reference
    config: {}
routing_table:
  inference:
  - provider_id: remote::tgi
    config:
      url: http://localhost:5099
      api_token: null
      hf_endpoint_name: null
    routing_key: Llama-Guard-3-8B
  safety:
  - provider_id: meta-reference
    config:
      llama_guard_shield:
        model: Llama-Guard-3-8B
        excluded_categories: []
        disable_input_check: false
        disable_output_check: false
      prompt_guard_shield: null
    routing_key: llama_guard
```

Now simply run `python -m llama_stack.apis.safety.client localhost
<port>` and check that the llama_guard shield calls run correctly. (The
injection_shield calls fail as expected since we have not set up a
router for them.)
This commit is contained in:
Ashwin Bharambe 2024-09-24 17:02:57 -07:00
parent c4534217c8
commit 0d2eb3bd25
9 changed files with 56 additions and 81 deletions

View file

@ -368,17 +368,19 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
providers = all_providers[info.router_api]
inner_specs = []
inner_deps = []
for rt_entry in routing_table:
if rt_entry.provider_id not in providers:
raise ValueError(
f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
)
inner_specs.append(providers[rt_entry.provider_id])
inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
specs[source_api] = RoutingTableProviderSpec(
api=source_api,
module="llama_stack.distribution.routers",
api_dependencies=[],
api_dependencies=inner_deps,
inner_specs=inner_specs,
)
configs[source_api] = routing_table