Use inference APIs for running Llama Guard

Test Plan:

First, start a TGI container serving the `meta-llama/Llama-Guard-3-8B` model
on port 5099. See https://github.com/meta-llama/llama-stack/pull/53 and its
description for instructions on how to do this.
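
Before launching the stack, it helps to verify that the TGI endpoint is actually up and serving the expected model. A minimal sanity check, assuming TGI's standard `/health` and `/info` routes and the port chosen above:

```python
# Sanity-check the TGI container before starting llama-stack.
# Assumes TGI's standard /health and /info routes; adjust BASE_URL if you
# used a port other than 5099.
import json
import urllib.request

BASE_URL = "http://localhost:5099"

# /health returns 200 once the model is loaded and ready to serve.
with urllib.request.urlopen(f"{BASE_URL}/health") as resp:
    assert resp.status == 200, f"TGI not healthy: HTTP {resp.status}"

# /info reports which model the server is actually running.
with urllib.request.urlopen(f"{BASE_URL}/info") as resp:
    info = json.load(resp)
    print("Serving:", info.get("model_id"))  # expect meta-llama/Llama-Guard-3-8B
```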

Then run llama-stack with the following run config:

```
image_name: safety
docker_image: null
conda_env: safety
apis_to_serve:
- models
- inference
- shields
- safety
api_providers:
  inference:
    providers:
    - remote::tgi
  safety:
    providers:
    - meta-reference
  telemetry:
    provider_id: meta-reference
    config: {}
routing_table:
  inference:
  - provider_id: remote::tgi
    config:
      url: http://localhost:5099
      api_token: null
      hf_endpoint_name: null
    routing_key: Llama-Guard-3-8B
  safety:
  - provider_id: meta-reference
    config:
      llama_guard_shield:
        model: Llama-Guard-3-8B
        excluded_categories: []
        disable_input_check: false
        disable_output_check: false
      prompt_guard_shield: null
    routing_key: llama_guard
```
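
The point of this config (and of the commit) is that the shield no longer loads Llama Guard in-process; it goes back through the inference API. A safety request therefore makes two routed hops: the safety router resolves the `llama_guard` routing key to the meta-reference shield, and that shield's inference call resolves `Llama-Guard-3-8B` to the TGI provider. A toy sketch of that flow; the class and method names are illustrative, not the real llama-stack interfaces:

```python
# Illustration of the two routed hops wired up by the config above.
# Names are stand-ins for the real provider implementations.
class ToyTGIProvider:
    def chat_completion(self, model: str, messages: list) -> str:
        # In the real stack this is an HTTP call to the TGI server on port 5099.
        return "safe"


class ToyLlamaGuardShield:
    def __init__(self, inference_router: dict, model: str) -> None:
        self.inference_router = inference_router
        self.model = model

    def run(self, messages: list) -> dict:
        # The shield is now just another inference client.
        provider = self.inference_router[self.model]   # "Llama-Guard-3-8B" -> TGI
        verdict = provider.chat_completion(self.model, messages)
        return {"is_violation": not verdict.startswith("safe")}


inference_router = {"Llama-Guard-3-8B": ToyTGIProvider()}
safety_router = {"llama_guard": ToyLlamaGuardShield(inference_router, "Llama-Guard-3-8B")}

shield = safety_router["llama_guard"]                      # routing_key: llama_guard
print(shield.run([{"role": "user", "content": "hello"}]))  # {'is_violation': False}
```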

Now simply run `python -m llama_stack.apis.safety.client localhost
<port>` and check that the llama_guard shield calls run correctly. (The
injection_shield calls fail as expected since we have not set up a
router for them.)
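
For context on what a successful `llama_guard` call exercises: Llama Guard is itself just a language model, so the shield formats the conversation into a classification prompt, sends it through the inference provider (here TGI), and reads back a `safe` / `unsafe` verdict. The sketch below goes straight at TGI's `/generate` endpoint with a hand-written prompt; the real prompt template and category parsing live in the meta-reference safety provider and are more elaborate than this:

```python
# Conceptual sketch of a Llama Guard check routed through an inference endpoint.
# The prompt and parsing are simplified stand-ins for the meta-reference
# provider's real Llama Guard template.
import json
import urllib.request

TGI_GENERATE_URL = "http://localhost:5099/generate"

def check_with_llama_guard(user_message: str) -> dict:
    prompt = (
        "Task: Check if there is unsafe content in the conversation below.\n\n"
        f"User: {user_message}\n\n"
        "Answer with 'safe' or 'unsafe'."
    )
    body = json.dumps({"inputs": prompt, "parameters": {"max_new_tokens": 20}}).encode()
    req = urllib.request.Request(
        TGI_GENERATE_URL, data=body, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req) as resp:
        generated = json.load(resp)["generated_text"]
    # Llama Guard answers "safe", or "unsafe" followed by violated category codes.
    verdict = generated.strip().splitlines()[0].lower()
    return {"is_safe": verdict.startswith("safe"), "raw": generated}

print(check_with_llama_guard("How do I fold a paper airplane?"))
```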
Ashwin Bharambe, 2024-09-24 17:02:57 -07:00
parent c4534217c8, commit 0d2eb3bd25
9 changed files with 56 additions and 81 deletions


```diff
@@ -103,8 +103,7 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> AsyncGenerator:
-        # TODO: we need to fix streaming response to align provider implementations with Protocol.
-        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
+        params = dict(
             model=model,
             messages=messages,
             sampling_params=sampling_params,
@@ -113,6 +112,10 @@ class InferenceRouter(Inference):
             tool_prompt_format=tool_prompt_format,
             stream=stream,
             logprobs=logprobs,
+        )
+        # TODO: we need to fix streaming response to align provider implementations with Protocol.
+        async for chunk in self.routing_table.get_provider_impl(model).chat_completion(
+            **params
         ):
             yield chunk
```


```diff
@@ -33,8 +33,10 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.providers.values():
             await p.shutdown()
 
-    def get_provider_impl(self, routing_key: str) -> Optional[Any]:
-        return self.providers.get(routing_key)
+    def get_provider_impl(self, routing_key: str) -> Any:
+        if routing_key not in self.providers:
+            raise ValueError(f"Could not find provider for {routing_key}")
+        return self.providers[routing_key]
 
     def get_routing_keys(self) -> List[str]:
         return self.routing_keys
```

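One visible effect of this change is in the test plan above: a lookup for a key with no routing entry (like the `injection_shield` calls) now fails with an explicit `ValueError` instead of returning `None` and breaking later. A small illustration with a stand-in provider dict; the `injection_shield` key name is only illustrative:

```python
# Stand-in for CommonRoutingTableImpl holding only the providers configured above.
providers = {
    "Llama-Guard-3-8B": "remote::tgi inference provider",
    "llama_guard": "meta-reference safety provider",
}

def get_provider_impl(routing_key: str):
    # New behavior from this commit: unknown keys raise instead of returning None.
    if routing_key not in providers:
        raise ValueError(f"Could not find provider for {routing_key}")
    return providers[routing_key]

get_provider_impl("llama_guard")           # resolves fine
try:
    get_provider_impl("injection_shield")  # no routing entry -> explicit error
except ValueError as e:
    print(e)                               # Could not find provider for injection_shield
```
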

```diff
@@ -368,17 +368,19 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
         providers = all_providers[info.router_api]
 
         inner_specs = []
+        inner_deps = []
         for rt_entry in routing_table:
             if rt_entry.provider_id not in providers:
                 raise ValueError(
                     f"Unknown provider `{rt_entry.provider_id}` is not available for API `{api}`"
                 )
             inner_specs.append(providers[rt_entry.provider_id])
+            inner_deps.extend(providers[rt_entry.provider_id].api_dependencies)
 
         specs[source_api] = RoutingTableProviderSpec(
             api=source_api,
             module="llama_stack.distribution.routers",
-            api_dependencies=[],
+            api_dependencies=inner_deps,
             inner_specs=inner_specs,
         )
         configs[source_api] = routing_table
```
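
This last hunk is what lets the routed safety provider actually reach inference: the routing-table provider now surfaces its inner providers' `api_dependencies` instead of claiming none, so the resolver knows the meta-reference shield needs an inference API injected. A rough sketch of the dependency collection, using stripped-down stand-ins for the real spec objects:

```python
# Sketch of how inner_deps gets assembled. Api and ProviderSpec here are
# simplified stand-ins for the real llama-stack types.
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class Api(Enum):
    inference = "inference"
    safety = "safety"


@dataclass
class ProviderSpec:
    provider_id: str
    api_dependencies: List[Api] = field(default_factory=list)


# The meta-reference safety provider now depends on the inference API,
# since Llama Guard runs through it rather than being loaded in-process.
inner_specs = [ProviderSpec("meta-reference", api_dependencies=[Api.inference])]

inner_deps: List[Api] = []
for spec in inner_specs:
    inner_deps.extend(spec.api_dependencies)

# Passing inner_deps as the routing-table provider's api_dependencies means the
# resolver builds the inference router first and hands it to the shield.
print(inner_deps)  # [<Api.inference: 'inference'>]
```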