Merge branch 'main' into nvidia-e2e-notebook

This commit is contained in:
Jash Gulabrai 2025-05-28 17:48:15 -04:00
commit f5cb965f0f
226 changed files with 16519 additions and 8666 deletions

View file

@ -92,8 +92,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
model=model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
@ -139,8 +142,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id,
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,

View file

@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from llama_stack.schema_utils import json_schema_type
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake",
description="The API token",
)
tls_verify: bool = Field(
tls_verify: bool | str = Field(
default=True,
description="Whether to verify TLS certificates",
description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
)
@field_validator("tls_verify")
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Check if it's a boolean string
if v.lower() in ("true", "false"):
return v.lower() == "true"
# Otherwise, treat it as a cert path
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
raise ValueError(f"TLS certificate file does not exist: {v}")
if not cert_path.is_file():
raise ValueError(f"TLS certificate path is not a file: {v}")
return v
return v
@classmethod
def sample_run_config(
cls,

View file

@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
http_client=httpx.AsyncClient(verify=self.config.tls_verify),
)
async def completion(

View file

@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
Parameters:
training_config: TrainingConfig - Configuration for training
model: str - Model identifier
model: str - NeMo Customizer configuration name
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
job_uuid: str - Unique identifier for the job, ignored atm
@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
User is informed about unsupported parameters via warnings.
"""
# Map model to nvidia model name
# See `_MODEL_ENTRIES` for supported models
nvidia_model = self.get_provider_model_id(model)
# Check for unsupported method parameters
unsupported_method_params = []
@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
# Prepare base job configuration
job_config = {
"config": nvidia_model,
"config": model,
"dataset": {
"name": training_config["data_config"]["dataset_id"],
"namespace": self.config.dataset_namespace,

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SambaNovaSafetyConfig
async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any:
from .sambanova import SambaNovaSafetyAdapter
impl = SambaNovaSafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="Sambanova Cloud API key",
)
@json_schema_type
class SambaNovaSafetyConfig(BaseModel):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,
}

View file

@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any
import litellm
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
from .config import SambaNovaSafetyConfig
logger = logging.getLogger(__name__)
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
def __init__(self, config: SambaNovaSafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
def _get_api_key(self) -> str:
config_api_key = self.config.api_key if self.config.api_key else None
if config_api_key:
return config_api_key.get_secret_value()
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.sambanova_api_key:
raise ValueError(
'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
)
return provider_data.sambanova_api_key
async def register_shield(self, shield: Shield) -> None:
list_models_url = self.config.url + "/models"
try:
response = requests.get(list_models_url)
response.raise_for_status()
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Request to {list_models_url} failed") from e
available_models = [model.get("id") for model in response.json().get("data", {})]
if (
len(available_models) == 0
or "guard" not in shield.provider_resource_id.lower()
or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
):
raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
) -> RunShieldResponse:
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
shield_params = shield.params
logger.debug(f"run_shield::{shield_params}::messages={messages}")
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = litellm.completion(
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
)
shield_message = response.choices[0].message.content
if "unsafe" in shield_message.lower():
user_message = CANNED_RESPONSE_TEXT
violation_type = shield_message.split("\n")[-1]
metadata = {"violation_type": violation_type}
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse()

View file

@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BingSearchToolConfig
class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BingSearchToolConfig):
self.config = config
self.url = "https://api.bing.microsoft.com/v7.0/search"
@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -11,30 +11,30 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -4,18 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import ModelContextProtocolConfig
from .config import MCPProviderConfig
class ModelContextProtocolToolProviderDataValidator(BaseModel):
api_key: str
async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
async def get_adapter_impl(config: MCPProviderConfig, _deps):
from .model_context_protocol import ModelContextProtocolToolRuntimeImpl
impl = ModelContextProtocolToolRuntimeImpl(config)
impl = ModelContextProtocolToolRuntimeImpl(config, _deps)
await impl.initialize()
return impl

View file

@ -9,7 +9,12 @@ from typing import Any
from pydantic import BaseModel
class ModelContextProtocolConfig(BaseModel):
class MCPProviderDataValidator(BaseModel):
# mcp_endpoint => dict of headers to send
mcp_headers: dict[str, dict[str, str]] | None = None
class MCPProviderConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {}

View file

@ -7,61 +7,45 @@
from typing import Any
from urllib.parse import urlparse
from mcp import ClientSession
from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datatypes import Api
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
from .config import ModelContextProtocolConfig
from .config import MCPProviderConfig
logger = get_logger(__name__, category="tools")
class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
def __init__(self, config: ModelContextProtocolConfig):
class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
self.config = config
async def initialize(self):
pass
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
# this endpoint should be retrieved by getting the tool group right?
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
tools = []
async with sse_client(mcp_endpoint.uri) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
parameters.append(
ToolParameter(
name=param_name,
parameter_type=param_schema.get("type", "string"),
description=param_schema.get("description", ""),
)
)
tools.append(
ToolDef(
name=tool.name,
description=tool.description,
parameters=parameters,
metadata={
"endpoint": mcp_endpoint.uri,
},
)
)
return ListToolDefsResponse(data=tools)
headers = await self.get_headers_from_request(mcp_endpoint.uri)
return await list_mcp_tools(mcp_endpoint.uri, headers)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)
@ -71,12 +55,19 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
if urlparse(endpoint).scheme not in ("http", "https"):
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
async with sse_client(endpoint) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
result = await session.call_tool(tool.identifier, kwargs)
headers = await self.get_headers_from_request(endpoint)
return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
return ToolInvocationResult(
content="\n".join([result.model_dump_json() for result in result.content]),
error_code=1 if result.isError else 0,
)
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
def canonicalize_uri(uri: str) -> str:
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
headers = {}
provider_data = self.get_request_provider_data()
if provider_data and provider_data.mcp_headers:
for uri, values in provider_data.mcp_headers.items():
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
continue
headers.update(values)
return headers

View file

@ -12,29 +12,29 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import TavilySearchToolConfig
class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: TavilySearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -12,19 +12,19 @@ import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolGroup,
ToolInvocationResult,
ToolParameter,
ToolRuntime,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from .config import WolframAlphaToolConfig
class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
def __init__(self, config: WolframAlphaToolConfig):
self.config = config
self.url = "https://api.wolframalpha.com/v2/query"
@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def initialize(self):
pass
async def register_tool(self, tool: Tool) -> None:
async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass
async def unregister_tool(self, tool_id: str) -> None:
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return
def _get_api_key(self) -> str:

View file

@ -84,6 +84,14 @@ class ChromaIndex(EmbeddingIndex):
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(

View file

@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex):
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
raise e
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
@ -86,6 +86,14 @@ class MilvusIndex(EmbeddingIndex):
scores = [res["distance"] for res in search_res[0]]
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(

View file

@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(
f"""
@ -120,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in PGVector")
async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View file

@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,
@ -95,6 +95,14 @@ class QdrantIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)

View file

@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
@ -84,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex):
collection = self.client.collections.get(self.collection_name)
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Weaviate")
class WeaviateVectorIOAdapter(
VectorIO,