forked from phoenix-oss/llama-stack-mirror
Merge branch 'main' into eval_api_final
This commit is contained in:
commit
66cd83fb58
37 changed files with 1215 additions and 840 deletions
|
@ -44,7 +44,9 @@ class PandasDataframeDataset:
|
|||
elif self.dataset_def.source.type == "rows":
|
||||
self.df = pandas.DataFrame(self.dataset_def.source.rows)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
|
||||
raise ValueError(
|
||||
f"Unsupported dataset source type: {self.dataset_def.source.type}"
|
||||
)
|
||||
|
||||
if self.df is None:
|
||||
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
|
||||
|
@ -108,7 +110,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_index=end if end < len(dataset_impl) else None,
|
||||
next_start_index=end if end < len(dataset_impl) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
|
@ -117,4 +119,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
|
||||
dataset_impl.df = pandas.concat(
|
||||
[dataset_impl.df, new_rows_df], ignore_index=True
|
||||
)
|
||||
|
|
|
@ -55,4 +55,13 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -86,7 +86,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
|
||||
return IterrowsResponse(
|
||||
data=rows,
|
||||
next_index=end if end < len(loaded_dataset) else None,
|
||||
next_start_index=end if end < len(loaded_dataset) else None,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
|
@ -98,9 +98,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
new_dataset = hf_datasets.Dataset.from_list(rows)
|
||||
|
||||
# Concatenate the new rows with existing dataset
|
||||
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
|
||||
updated_dataset = hf_datasets.concatenate_datasets(
|
||||
[loaded_dataset, new_dataset]
|
||||
)
|
||||
|
||||
if dataset_def.metadata.get("path", None):
|
||||
updated_dataset.push_to_hub(dataset_def.metadata["path"])
|
||||
else:
|
||||
raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
|
||||
raise NotImplementedError(
|
||||
"Uploading to URL-based datasets is not supported yet"
|
||||
)
|
||||
|
|
|
@ -12,6 +12,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
|
|||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
|
@ -160,12 +161,14 @@ class PassthroughInferenceAdapter(Inference):
|
|||
client = self._get_client()
|
||||
response = await client.inference.chat_completion(**json_params)
|
||||
|
||||
response = response.to_dict()
|
||||
|
||||
# temporary hack to remove the metrics from the response
|
||||
response["metrics"] = []
|
||||
|
||||
return convert_to_pydantic(ChatCompletionResponse, response)
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=response.completion_message.content.text,
|
||||
stop_reason=response.completion_message.stop_reason,
|
||||
tool_calls=response.completion_message.tool_calls,
|
||||
),
|
||||
logprobs=response.logprobs,
|
||||
)
|
||||
|
||||
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
|
||||
client = self._get_client()
|
||||
|
|
18
llama_stack/providers/remote/safety/nvidia/__init__.py
Normal file
18
llama_stack/providers/remote/safety/nvidia/__init__.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
|
||||
from .nvidia import NVIDIASafetyAdapter
|
||||
|
||||
impl = NVIDIASafetyAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
37
llama_stack/providers/remote/safety/nvidia/config.py
Normal file
37
llama_stack/providers/remote/safety/nvidia/config.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class NVIDIASafetyConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the NVIDIA Guardrail microservice endpoint.
|
||||
|
||||
Attributes:
|
||||
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
|
||||
config_id (str): The ID of the guardrails configuration to use from the configuration store
|
||||
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
|
||||
|
||||
"""
|
||||
|
||||
guardrails_service_url: str = Field(
|
||||
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
|
||||
description="The url for accessing the guardrails service",
|
||||
)
|
||||
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
|
||||
"config_id": "self-check",
|
||||
}
|
154
llama_stack/providers/remote/safety/nvidia/nvidia.py
Normal file
154
llama_stack/providers/remote/safety/nvidia/nvidia.py
Normal file
|
@ -0,0 +1,154 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||
"""
|
||||
Initialize the NVIDIASafetyAdapter with a given safety configuration.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
|
||||
"""
|
||||
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if not shield.provider_resource_id:
|
||||
raise ValueError("Shield model not provided.")
|
||||
|
||||
async def run_shield(
|
||||
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
|
||||
) -> RunShieldResponse:
|
||||
"""
|
||||
Run a safety shield check against the provided messages.
|
||||
|
||||
Args:
|
||||
shield_id (str): The unique identifier for the shield to be used.
|
||||
messages (List[Message]): A list of Message objects representing the conversation history.
|
||||
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: The response containing safety violation details if any.
|
||||
|
||||
Raises:
|
||||
ValueError: If the shield with the provided shield_id is not found.
|
||||
"""
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||
return await self.shield.run(messages)
|
||||
|
||||
|
||||
class NeMoGuardrails:
|
||||
"""
|
||||
A class that encapsulates NVIDIA's guardrails safety logic.
|
||||
|
||||
Sends messages to the guardrails service and interprets the response to determine
|
||||
if a safety violation has occurred.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NVIDIASafetyConfig,
|
||||
model: str,
|
||||
threshold: float = 0.9,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize a NeMoGuardrails instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
|
||||
model (str): The identifier or name of the model to be used for safety checks.
|
||||
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
|
||||
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
|
||||
|
||||
Raises:
|
||||
ValueError: If temperature is less than or equal to 0.
|
||||
AssertionError: If config_id is not provided in the configuration.
|
||||
"""
|
||||
self.config_id = config.config_id
|
||||
self.model = model
|
||||
assert self.config_id is not None, "Must provide config id"
|
||||
if temperature <= 0:
|
||||
raise ValueError("Temperature must be greater than 0")
|
||||
|
||||
self.temperature = temperature
|
||||
self.threshold = threshold
|
||||
self.guardrails_service_url = config.guardrails_service_url
|
||||
|
||||
async def run(self, messages: List[Message]) -> RunShieldResponse:
|
||||
"""
|
||||
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
|
||||
|
||||
Args:
|
||||
messages (List[Message]): A list of Message objects to be checked for safety violations.
|
||||
|
||||
Returns:
|
||||
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
|
||||
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
|
||||
|
||||
Raises:
|
||||
requests.HTTPError: If the POST request fails.
|
||||
"""
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
}
|
||||
request_data = {
|
||||
"model": self.model,
|
||||
"messages": convert_pydantic_to_json_value(messages),
|
||||
"temperature": self.temperature,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": self.config_id,
|
||||
},
|
||||
}
|
||||
response = requests.post(
|
||||
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
|
||||
)
|
||||
response.raise_for_status()
|
||||
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
|
||||
response_json = response.json()
|
||||
if response_json["status"] == "blocked":
|
||||
user_message = "Sorry I cannot do this."
|
||||
metadata = response_json["rails_status"]
|
||||
|
||||
return RunShieldResponse(
|
||||
violation=SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
return RunShieldResponse(violation=None)
|
Loading…
Add table
Add a link
Reference in a new issue