right naming

This commit is contained in:
Dinesh Yeduguru 2024-11-07 22:24:45 -08:00
parent 19d57b4d82
commit 04a2965967
10 changed files with 50 additions and 17 deletions

View file

@ -22,10 +22,22 @@ class ResourceType(Enum):
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
identifier: str = Field(description="Unique identifier for this resource")
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)
provider_resource_identifier: str = Field(
description="Unique identifier for this resource in the provider",
default=None,
)
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(
description="Type of resource (e.g. 'model', 'shield', 'memory_bank', etc.)"
)
# If the provider_resource_identifier is not set, set it to the identifier
def model_post_init(self, __context) -> None:
if self.provider_resource_identifier is None:
self.provider_resource_identifier = self.identifier

View file

@ -41,13 +41,13 @@ class SafetyClient(Safety):
pass
async def run_shield(
self, shield_type: str, messages: List[Message]
self, shield_id: str, messages: List[Message]
) -> RunShieldResponse:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.base_url}/safety/run_shield",
json=dict(
shield_type=shield_type,
shield_id=shield_id,
messages=[encodable_dict(m) for m in messages],
),
headers={
@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
)
cprint(f"User>{message.content}", "green")
response = await client.run_shield(
shield_type="llama_guard",
shield_id="llama_guard",
messages=[message],
)
print(response)
@ -91,7 +91,7 @@ async def run_main(host: str, port: int, image_path: str = None):
]:
cprint(f"User>{message.content}", "green")
response = await client.run_shield(
shield_type="llama_guard",
shield_id="llama_guard",
messages=[message],
)
print(response)

View file

@ -38,10 +38,18 @@ class RunShieldResponse(BaseModel):
violation: Optional[SafetyViolation] = None
class ShieldStore(Protocol):
async def get_shield(self, identifier: str) -> Shield: ...
@runtime_checkable
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield: Shield, messages: List[Message], params: Dict[str, Any] = None
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse: ...