llama-stack-mirror/llama_stack/apis/shields/shields.py
Ihar Hrachyshka 1deb95f922 chore: enable pyupgrade fixes
Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
2025-05-01 17:02:13 -04:00

61 lines
1.7 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonShieldFields(BaseModel):
params: dict[str, Any] | None = None
@json_schema_type
class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value
@property
def shield_id(self) -> str:
return self.identifier
@property
def provider_shield_id(self) -> str:
return self.provider_resource_id
class ShieldInput(CommonShieldFields):
shield_id: str
provider_id: str | None = None
provider_shield_id: str | None = None
class ListShieldsResponse(BaseModel):
data: list[Shield]
@runtime_checkable
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Shield: ...
@webmethod(route="/shields", method="POST")
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield: ...