Use inference APIs for running Llama Guard

Test Plan:

First, start a TGI container serving the `meta-llama/Llama-Guard-3-8B` model
on port 5099. See https://github.com/meta-llama/llama-stack/pull/53 and its
description for instructions on how to do this.
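
Before wiring the container into the stack, it can be worth confirming it is actually reachable and serving the right model. A minimal sketch, assuming TGI's standard `/info` endpoint and that port 5099 is mapped to the container:

```python
import requests

# Probe the TGI server to confirm the Llama Guard model is being served.
info = requests.get("http://localhost:5099/info", timeout=10).json()
print(info.get("model_id"))  # expect: meta-llama/Llama-Guard-3-8B
```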

Then run llama-stack with the following run config:

```
image_name: safety
docker_image: null
conda_env: safety
apis_to_serve:
- models
- inference
- shields
- safety
api_providers:
  inference:
    providers:
    - remote::tgi
  safety:
    providers:
    - meta-reference
  telemetry:
    provider_id: meta-reference
    config: {}
routing_table:
  inference:
  - provider_id: remote::tgi
    config:
      url: http://localhost:5099
      api_token: null
      hf_endpoint_name: null
    routing_key: Llama-Guard-3-8B
  safety:
  - provider_id: meta-reference
    config:
      llama_guard_shield:
        model: Llama-Guard-3-8B
        excluded_categories: []
        disable_input_check: false
        disable_output_check: false
      prompt_guard_shield: null
    routing_key: llama_guard
```
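
Since the safety provider now goes through the inference API to run Llama Guard, the shield's `model` must line up with a key in the inference routing table. A quick consistency check before launching (a sketch; saving the YAML above as `run.yaml` is just an assumption):

```python
import yaml

# Load the run config (saved here as run.yaml -- the filename is an assumption).
with open("run.yaml") as f:
    cfg = yaml.safe_load(f)

shield_model = cfg["routing_table"]["safety"][0]["config"]["llama_guard_shield"]["model"]
inference_keys = [entry["routing_key"] for entry in cfg["routing_table"]["inference"]]

# The shield's model must be resolvable through the inference routing table,
# otherwise the safety provider has nowhere to run Llama Guard.
assert shield_model in inference_keys, (shield_model, inference_keys)
print(f"ok: {shield_model} routes to an inference provider")
```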

Now simply run `python -m llama_stack.apis.safety.client localhost
<port>` and check that the llama_guard shield calls run correctly. (The
injection_shield calls fail as expected since we have not set up a
router for them.)
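
If you prefer to poke the server by hand instead of using the bundled client, a request along these lines should exercise the same path. This is a minimal sketch only: the route and payload shape are assumptions based on this config, so treat the packaged client as the reference.

```python
import requests

# The /safety/run_shield route and payload shape below are assumptions; the
# bundled client (llama_stack.apis.safety.client) is the source of truth.
resp = requests.post(
    "http://localhost:5000/safety/run_shield",  # use the port the stack serves on
    json={
        "shield_type": "llama_guard",  # matches the safety routing_key above
        "messages": [{"role": "user", "content": "How do I hotwire a car?"}],
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # expect a Llama Guard verdict for the message
```
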
commit 0d2eb3bd25 (parent c4534217c8)
Author: Ashwin Bharambe
Date:   2024-09-24 17:02:57 -07:00
9 changed files with 56 additions and 81 deletions

```
@@ -33,8 +33,10 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.providers.values():
             await p.shutdown()
 
-    def get_provider_impl(self, routing_key: str) -> Optional[Any]:
-        return self.providers.get(routing_key)
+    def get_provider_impl(self, routing_key: str) -> Any:
+        if routing_key not in self.providers:
+            raise ValueError(f"Could not find provider for {routing_key}")
+        return self.providers[routing_key]
 
     def get_routing_keys(self) -> List[str]:
         return self.routing_keys
```
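
The behavioral change in `get_provider_impl` is worth calling out: a missing routing key now fails fast with a descriptive error instead of returning `None` and surfacing later at the call site. A standalone illustration of that fail-fast pattern (the class below is a hypothetical stand-in, not the stack's own routing table):

```python
from typing import Any, Dict


class ToyRoutingTable:
    """Hypothetical stand-in for the routing table, not the stack's own class."""

    def __init__(self, providers: Dict[str, Any]) -> None:
        self.providers = providers

    def get_provider_impl(self, routing_key: str) -> Any:
        # Mirrors the new behavior: fail fast with the offending key instead of
        # returning None for the caller to trip over later.
        if routing_key not in self.providers:
            raise ValueError(f"Could not find provider for {routing_key}")
        return self.providers[routing_key]


table = ToyRoutingTable({"Llama-Guard-3-8B": "remote::tgi provider"})
print(table.get_provider_impl("Llama-Guard-3-8B"))  # -> remote::tgi provider
try:
    table.get_provider_impl("llama-guard")  # typo'd routing key
except ValueError as err:
    print(err)  # Could not find provider for llama-guard
```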