Merge branch 'main' into eval_api_final

This commit is contained in:
Xi Yan 2025-03-17 17:00:30 -07:00
commit 66cd83fb58
37 changed files with 1215 additions and 840 deletions

View file

@ -44,7 +44,9 @@ class PandasDataframeDataset:
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
raise ValueError(
f"Unsupported dataset source type: {self.dataset_def.source.type}"
)
if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
@ -108,7 +110,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return IterrowsResponse(
data=rows,
next_index=end if end < len(dataset_impl) else None,
next_start_index=end if end < len(dataset_impl) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
@ -117,4 +119,6 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)

View file

@ -55,4 +55,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
),
),
]

View file

@ -86,7 +86,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
return IterrowsResponse(
data=rows,
next_index=end if end < len(loaded_dataset) else None,
next_start_index=end if end < len(loaded_dataset) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
@ -98,9 +98,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
updated_dataset = hf_datasets.concatenate_datasets(
[loaded_dataset, new_dataset]
)
if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"])
else:
raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
raise NotImplementedError(
"Uploading to URL-based datasets is not supported yet"
)

View file

@ -12,6 +12,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -160,12 +161,14 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=response.completion_message.content.text,
stop_reason=response.completion_message.stop_reason,
tool_calls=response.completion_message.tool_calls,
),
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
config_id (str): The ID of the guardrails configuration to use from the configuration store
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
}

View file

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, List, Optional
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
"""
Initialize the NVIDIASafetyAdapter with a given safety configuration.
Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
"""
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
Raises:
ValueError: If the shield with the provided shield_id is not found.
"""
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
"""
A class that encapsulates NVIDIA's guardrails safety logic.
Sends messages to the guardrails service and interprets the response to determine
if a safety violation has occurred.
"""
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
"""
Initialize a NeMoGuardrails instance with the provided parameters.
Args:
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
model (str): The identifier or name of the model to be used for safety checks.
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
Raises:
ValueError: If temperature is less than or equal to 0.
AssertionError: If config_id is not provided in the configuration.
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None, "Must provide config id"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
Args:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
Raises:
requests.HTTPError: If the POST request fails.
"""
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)