test safety against safety client

This commit is contained in:
Ashwin Bharambe 2024-09-20 14:55:00 -07:00 committed by Xi Yan
parent d6a41d98d2
commit 9252e81a7b
19 changed files with 1076 additions and 10754 deletions

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .config import BedrockSafetyRequestProviderData # noqa: F403
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
from .bedrock import BedrockSafetyAdapter
impl = BedrockSafetyAdapter(config.url)
await impl.initialize()
return impl

View file

@ -1,52 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils import get_request_provider_data
from .config import BedrockSafetyRequestProviderData
class BedrockSafetyAdapter(Safety):
def __init__(self, url: str) -> None:
self.url = url
pass
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_shield(
self,
shield: str,
messages: List[Message],
) -> RunShieldResponse:
# clients will set api_keys by doing something like:
#
# client = llama_stack.LlamaStack()
# await client.safety.run_shield(
# shield_type="aws_guardrail_type",
# messages=[ ... ],
# x_llamastack_provider_data={
# "aws_api_key": "..."
# }
# )
#
# This information will arrive at the LlamaStack server via a HTTP Header.
#
# The server will then provide you a type-checked version of this provider data
# automagically by extracting it from the header and validating it with the
# BedrockSafetyRequestProviderData class you will need to register in the provider
# registry.
#
provider_data: BedrockSafetyRequestProviderData = get_request_provider_data()
# use `aws_api_key` to pass to the AWS servers in whichever form
raise NotImplementedError()

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class BedrockSafetyRequestProviderData(BaseModel):
aws_api_key: str
# other AWS specific keys you may need