Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-03-18 12:06:58 -05:00 committed by GitHub
commit 5bd1bd30e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
76 changed files with 3534 additions and 2843 deletions

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig
@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
DATASETS_PREFIX = "datasets:"
def load_hf_dataset(dataset_def: Dataset):
if dataset_def.metadata.get("path", None):
dataset = hf_datasets.load_dataset(**dataset_def.metadata)
else:
df = get_dataframe_from_url(dataset_def.url)
def parse_hf_params(dataset_def: Dataset):
uri = dataset_def.source.uri
parsed_uri = urlparse(uri)
params = parse_qs(parsed_uri.query)
params = {k: v[0] for k, v in params.items()}
path = parsed_uri.path.lstrip("/")
if df is None:
raise ValueError(f"Failed to load dataset from {dataset_def.url}")
dataset = hf_datasets.Dataset.from_pandas(df)
# drop columns not specified by schema
if dataset_def.dataset_schema:
dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
return dataset
return path, params
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset_def.json(),
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def
@ -73,41 +65,34 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)
start = next_page_token
if rows_in_page == -1:
if limit is None or limit == -1:
end = len(loaded_dataset)
else:
end = min(start + rows_in_page, len(loaded_dataset))
end = min(start_index + limit, len(loaded_dataset))
rows = [loaded_dataset[i] for i in range(start, end)]
rows = [loaded_dataset[i] for i in range(start_index, end)]
return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
return IterrowsResponse(
data=rows,
next_start_index=end if end < len(loaded_dataset) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
# Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows)

View file

@ -12,6 +12,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -160,12 +161,14 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=response.completion_message.content.text,
stop_reason=response.completion_message.stop_reason,
tool_calls=response.completion_message.tool_calls,
),
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()

View file

@ -25,6 +25,10 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake",
description="The API token",
)
tls_verify: bool = Field(
default=True,
description="Whether to verify TLS certificates",
)
@classmethod
def sample_run_config(
@ -36,4 +40,5 @@ class VLLMInferenceAdapterConfig(BaseModel):
"url": url,
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"api_token": "${env.VLLM_API_TOKEN:fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:true}",
}

View file

@ -7,6 +7,7 @@ import json
import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
@ -229,7 +230,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
self.client = AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
config_id (str): The ID of the guardrails configuration to use from the configuration store
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
}

View file

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, List, Optional
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
"""
Initialize the NVIDIASafetyAdapter with a given safety configuration.
Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
"""
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
Raises:
ValueError: If the shield with the provided shield_id is not found.
"""
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
"""
A class that encapsulates NVIDIA's guardrails safety logic.
Sends messages to the guardrails service and interprets the response to determine
if a safety violation has occurred.
"""
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
"""
Initialize a NeMoGuardrails instance with the provided parameters.
Args:
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
model (str): The identifier or name of the model to be used for safety checks.
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
Raises:
ValueError: If temperature is less than or equal to 0.
AssertionError: If config_id is not provided in the configuration.
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None, "Must provide config id"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
Args:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
Raises:
requests.HTTPError: If the POST request fails.
"""
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)