pre-commit fixes

This commit is contained in:
Chantal D Gama Rose 2025-03-14 13:56:05 -07:00
parent 967dd0aa08
commit 7e211f8553
314 changed files with 5574 additions and 11369 deletions

View file

@ -42,7 +42,10 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
self,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
@ -50,7 +53,6 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Dict[str, Any], optional): Additional parameters for the safety check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
@ -96,7 +98,7 @@ class NeMoGuardrails:
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None("Must provide config id")
assert self.config_id is not None, "Must provide config id"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")

View file

@ -1,17 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleSafetyImpl
impl = SampleSafetyImpl(config)
await impl.initialize()
return impl

View file

@ -1,12 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -1,23 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.safety import Safety
from llama_stack.apis.shields import Shield
from .config import SampleConfig
class SampleSafetyImpl(Safety):
def __init__(self, config: SampleConfig):
self.config = config
async def register_shield(self, shield: Shield) -> None:
# these are the safety shields the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def initialize(self):
pass