mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-02 20:40:00 +00:00
pre-commit fixes
This commit is contained in:
parent
967dd0aa08
commit
7e211f8553
314 changed files with 5574 additions and 11369 deletions
|
|
@ -42,7 +42,10 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
raise ValueError("Shield model not provided.")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
"""
|
||||
Run a safety shield check against the provided messages.
|
||||
|
|
@ -50,7 +53,6 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
Args:
|
||||
shield_id (str): The unique identifier for the shield to be used.
|
||||
messages (List[Message]): A list of Message objects representing the conversation history.
|
||||
params (Dict[str, Any], optional): Additional parameters for the safety check.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: The response containing safety violation details if any.
|
||||
|
|
@ -96,7 +98,7 @@ class NeMoGuardrails:
|
|||
"""
|
||||
self.config_id = config.config_id
|
||||
self.model = model
|
||||
assert self.config_id is not None("Must provide config id")
|
||||
assert self.config_id is not None, "Must provide config id"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
|
||||
from .sample import SampleSafetyImpl
|
||||
|
||||
impl = SampleSafetyImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
host: str = "localhost"
|
||||
port: int = 9999
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
||||
from .config import SampleConfig
|
||||
|
||||
|
||||
class SampleSafetyImpl(Safety):
|
||||
def __init__(self, config: SampleConfig):
|
||||
self.config = config
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# these are the safety shields the Llama Stack will use to route requests to this provider
|
||||
# perform validation here if necessary
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
Loading…
Add table
Add a link
Reference in a new issue