Merge branch 'main' into nvidia-e2e-notebook

Jash Gulabrai 2025-04-30 12:05:11 -04:00
commit 012dd6891f
96 changed files with 4675 additions and 426 deletions


@ -38,6 +38,13 @@ from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from .openai_responses import (
OpenAIResponseInputMessage,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
)
class Attachment(BaseModel):
"""An attachment to an agent turn.
@ -593,3 +600,39 @@ class Agents(Protocol):
:returns: A ListAgentSessionsResponse.
"""
...
# We situate the OpenAI Responses API in the Agents API just as we did for
# Inference. The Responses API, in its intent, serves the same purpose as
# the Agents API above -- it is essentially a lightweight "agentic loop" with
# integrated tool calling.
#
# Both of these APIs are inherently stateful.
@webmethod(route="/openai/v1/responses/{id}", method="GET")
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
"""Retrieve an OpenAI response by its ID.
:param id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject.
"""
...
@webmethod(route="/openai/v1/responses", method="POST")
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
tools: Optional[List[OpenAIResponseInputTool]] = None,
) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
"""Create a new OpenAI response.
:param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
"""


@ -0,0 +1,140 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class OpenAIResponseError(BaseModel):
code: str
message: str
@json_schema_type
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
text: str
type: Literal["output_text"] = "output_text"
OpenAIResponseOutputMessageContent = Annotated[
Union[OpenAIResponseOutputMessageContentOutputText,],
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
@json_schema_type
class OpenAIResponseOutputMessage(BaseModel):
id: str
content: List[OpenAIResponseOutputMessageContent]
role: Literal["assistant"] = "assistant"
status: str
type: Literal["message"] = "message"
@json_schema_type
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
id: str
status: str
type: Literal["web_search_call"] = "web_search_call"
OpenAIResponseOutput = Annotated[
Union[
OpenAIResponseOutputMessage,
OpenAIResponseOutputMessageWebSearchToolCall,
],
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@json_schema_type
class OpenAIResponseObject(BaseModel):
created_at: int
error: Optional[OpenAIResponseError] = None
id: str
model: str
object: Literal["response"] = "response"
output: List[OpenAIResponseOutput]
parallel_tool_calls: bool = False
previous_response_id: Optional[str] = None
status: str
temperature: Optional[float] = None
top_p: Optional[float] = None
truncation: Optional[str] = None
user: Optional[str] = None
@json_schema_type
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
response: OpenAIResponseObject
type: Literal["response.created"] = "response.created"
@json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
response: OpenAIResponseObject
type: Literal["response.completed"] = "response.completed"
OpenAIResponseObjectStream = Annotated[
Union[
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseCompleted,
],
Field(discriminator="type"),
]
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel):
text: str
type: Literal["input_text"] = "input_text"
@json_schema_type
class OpenAIResponseInputMessageContentImage(BaseModel):
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
type: Literal["input_image"] = "input_image"
# TODO: handle file_id
image_url: Optional[str] = None
# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@json_schema_type
class OpenAIResponseInputMessage(BaseModel):
content: Union[str, List[OpenAIResponseInputMessageContent]]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Optional[Literal["message"]] = "message"
@json_schema_type
class OpenAIResponseInputToolWebSearch(BaseModel):
type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
# TODO: actually use search_context_size somewhere...
search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
# TODO: add user_location
OpenAIResponseInputTool = Annotated[
Union[OpenAIResponseInputToolWebSearch,],
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
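A small sketch of how these discriminated unions behave under validation, using pydantic v2's `TypeAdapter`; the payload values are invented for illustration.

```python
from pydantic import TypeAdapter

# The `type` field selects the concrete variant of the annotated union.
output = TypeAdapter(OpenAIResponseOutput).validate_python(
    {
        "type": "message",
        "id": "msg_123",
        "status": "completed",
        "content": [{"type": "output_text", "text": "hello"}],
    }
)
assert isinstance(output, OpenAIResponseOutputMessage)
```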


@ -460,15 +460,17 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import (
prompt_guard_download_info,
prompt_guard_model_sku,
prompt_guard_download_info_map,
prompt_guard_model_sku_map,
)
prompt_guard = prompt_guard_model_sku()
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
prompt_guard_download_info_map = prompt_guard_download_info_map()
for model_id in model_ids:
if model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
if model_id in prompt_guard_model_sku_map.keys():
model = prompt_guard_model_sku_map[model_id]
info = prompt_guard_download_info_map[model_id]
else:
model = resolve_model(model_id)
if model is None:


@ -36,11 +36,11 @@ class ModelDescribe(Subcommand):
)
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
from .safety_models import prompt_guard_model_sku_map
prompt_guard = prompt_guard_model_sku()
if args.model_id == prompt_guard.model_id:
model = prompt_guard
prompt_guard_model_map = prompt_guard_model_sku_map()
if args.model_id in prompt_guard_model_map.keys():
model = prompt_guard_model_map[args.model_id]
else:
model = resolve_model(args.model_id)


@ -84,7 +84,7 @@ class ModelList(Subcommand):
)
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
from .safety_models import prompt_guard_model_skus
if args.downloaded:
return _run_model_list_downloaded_cmd()
@ -96,7 +96,7 @@ class ModelList(Subcommand):
]
rows = []
for model in all_registered_models() + [prompt_guard_model_sku()]:
for model in all_registered_models() + prompt_guard_model_skus():
if not args.show_all and not model.is_featured:
continue


@ -42,11 +42,12 @@ class ModelRemove(Subcommand):
)
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
from .safety_models import prompt_guard_model_sku
from .safety_models import prompt_guard_model_sku_map
prompt_guard = prompt_guard_model_sku()
if args.model == prompt_guard.model_id:
model = prompt_guard
prompt_guard_model_map = prompt_guard_model_sku_map()
if args.model in prompt_guard_model_map.keys():
model = prompt_guard_model_map[args.model]
else:
model = resolve_model(args.model)


@ -15,11 +15,11 @@ from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
model_id: str = "Prompt-Guard-86M"
model_id: str
huggingface_repo: str
description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
is_featured: bool = False
huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
max_seq_length: int = 2048
max_seq_length: int = 512
is_instruct_model: bool = False
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
arch_args: Dict[str, Any] = Field(default_factory=dict)
@ -30,18 +30,35 @@ class PromptGuardModel(BaseModel):
model_config = ConfigDict(protected_namespaces=())
def prompt_guard_model_sku():
return PromptGuardModel()
def prompt_guard_model_skus():
return [
PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
PromptGuardModel(
model_id="Llama-Prompt-Guard-2-86M",
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
),
PromptGuardModel(
model_id="Llama-Prompt-Guard-2-22M",
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
),
]
def prompt_guard_download_info():
return LlamaDownloadInfo(
folder="Prompt-Guard",
files=[
"model.safetensors",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer_config.json",
],
pth_size=1,
)
def prompt_guard_model_sku_map() -> Dict[str, Any]:
return {model.model_id: model for model in prompt_guard_model_skus()}
def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]:
return {
model.model_id: LlamaDownloadInfo(
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
files=[
"model.safetensors",
"special_tokens_map.json",
"tokenizer.json",
"tokenizer_config.json",
],
pth_size=1,
)
for model in prompt_guard_model_skus()
}
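A quick sketch of how the CLI changes above consume these maps; the model id is one of the entries defined in `prompt_guard_model_skus()`.

```python
sku_map = prompt_guard_model_sku_map()
info_map = prompt_guard_download_info_map()

model_id = "Llama-Prompt-Guard-2-22M"
if model_id in sku_map:
    model = sku_map[model_id]
    info = info_map[model_id]
    # Prints: meta-llama/Llama-Prompt-Guard-2-22M Llama-Prompt-Guard-2-22M
    print(model.huggingface_repo, info.folder)
```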


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
@ -235,10 +236,21 @@ class LoggingConfig(BaseModel):
)
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
CUSTOM = "custom"
class AuthenticationConfig(BaseModel):
endpoint: str = Field(
provider_type: AuthProviderType = Field(
...,
description="Endpoint URL to validate authentication tokens",
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
)
config: Dict[str, str] = Field(
...,
description="Provider-specific configuration",
)
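A hedged sketch of the new config in use; the `endpoint` key is what the custom provider later in this diff reads from `config`, and the URL is a placeholder.

```python
auth_config = AuthenticationConfig(
    provider_type=AuthProviderType.CUSTOM,
    config={"endpoint": "https://auth.example.com/validate"},  # placeholder URL
)
```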


@ -5,74 +5,29 @@
# the root directory of this source tree.
import json
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthenticationMiddleware:
"""Middleware that authenticates requests using an external auth endpoint.
"""Middleware that authenticates requests using configured authentication provider.
This middleware:
1. Extracts the Bearer token from the Authorization header
2. Sends it to the configured auth endpoint along with request details
3. Validates the response and extracts user attributes
2. Uses the configured auth provider to validate the token
3. Extracts user attributes from the provider's response
4. Makes these attributes available to the route handlers for access control
Authentication Request Format:
The middleware supports multiple authentication providers through the AuthProvider interface:
- Kubernetes: Validates tokens against the Kubernetes API server
- Custom: Validates tokens against a custom endpoint
Authentication Request Format for Custom Auth Provider:
```json
{
"api_key": "the-api-key-extracted-from-auth-header",
@ -105,21 +60,26 @@ class AuthenticationMiddleware:
}
```
Token Validation:
Each provider implements its own token validation logic:
- Kubernetes: Uses TokenReview API to validate service account tokens
- Custom: Sends token to custom endpoint for validation
Attribute-Based Access Control:
The attributes returned by the auth endpoint are used to determine which
The attributes returned by the auth provider are used to determine which
resources the user can access. Resources can specify required attributes
using the access_attributes field. For a user to access a resource:
1. All attribute categories specified in the resource must be present in the user's attributes
2. For each category, the user must have at least one matching value
If the auth endpoint doesn't return any attributes, the user will only be able to
If the auth provider doesn't return any attributes, the user will only be able to
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_endpoint):
def __init__(self, app, auth_config: AuthProviderConfig):
self.app = app
self.auth_endpoint = auth_endpoint
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
@ -129,66 +89,34 @@ class AuthenticationMiddleware:
if not auth_header or not auth_header.startswith("Bearer "):
return await self._send_auth_error(send, "Missing or invalid Authorization header")
api_key = auth_header.split("Bearer ", 1)[1]
token = auth_header.split("Bearer ", 1)[1]
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
# Build the auth request model
auth_request = AuthRequest(
api_key=api_key,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
# Validate token and get access attributes
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.auth_endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed: {response.status_code}")
return await self._send_auth_error(send, "Authentication failed")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [api_key],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
except Exception:
logger.exception("Error parsing authentication response")
return await self._send_auth_error(send, "Invalid authentication response format")
access_attributes = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout")
except ValueError as e:
logger.exception("Error during authentication")
return await self._send_auth_error(send, str(e))
except Exception:
logger.exception("Error during authentication")
return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control
if access_attributes:
user_attributes = access_attributes.model_dump(exclude_none=True)
else:
logger.warning("No access attributes, setting namespace to token by default")
user_attributes = {
"namespaces": [token],
}
# Store attributes in request scope
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):


@ -0,0 +1,262 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional
from urllib.parse import parse_qs
import httpx
from pydantic import BaseModel, Field
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel):
"""The format of the authentication response from the auth endpoint."""
access_attributes: Optional[AccessAttributes] = Field(
default=None,
description="""
Structured user attributes for attribute-based access control.
These attributes determine which resources the user can access.
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
Each attribute category contains a list of values that the user has for that category.
During access control checks, these values are compared against resource requirements.
Example with standard categories:
```json
{
"roles": ["admin", "data-scientist"],
"teams": ["ml-team"],
"projects": ["llama-3"],
"namespaces": ["research"]
}
```
""",
)
message: Optional[str] = Field(
default=None, description="Optional message providing additional context about the authentication result."
)
class AuthRequestContext(BaseModel):
path: str = Field(description="The path of the request being authenticated")
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
params: Dict[str, List[str]] = Field(
description="Query parameters from the original request, parsed as dictionary of lists"
)
class AuthRequest(BaseModel):
api_key: str = Field(description="The API key extracted from the Authorization header")
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
CUSTOM = "custom"
class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: Dict[str, str] = Field(..., description="Provider-specific configuration")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
"""Validate a token and return access attributes."""
pass
@abstractmethod
async def close(self):
"""Clean up any resources."""
pass
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: Dict[str, str]):
self.api_server_url = config["api_server_url"]
self.ca_cert_path = config.get("ca_cert_path")
self._client = None
async def _get_client(self):
"""Get or create a Kubernetes client."""
if self._client is None:
# kubernetes-client has no async support, see:
# https://github.com/kubernetes-client/python/issues/323
from kubernetes import client
from kubernetes.client import ApiClient
# Configure the client
configuration = client.Configuration()
configuration.host = self.api_server_url
if self.ca_cert_path:
configuration.ssl_ca_cert = self.ca_cert_path
configuration.verify_ssl = bool(self.ca_cert_path)
# Create API client
self._client = ApiClient(configuration)
return self._client
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
# Set the token in the client
client.set_default_header("Authorization", f"Bearer {token}")
# Make a request to validate the token
# We use the /api endpoint which requires authentication
from kubernetes.client import CoreV1Api
api = CoreV1Api(client)
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
# If we get here, the token is valid
# Extract user info from the token claims
import base64
# Decode the token (without verification since we've already validated it)
token_parts = token.split(".")
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
# Extract user information from the token
username = payload.get("sub", "")
groups = payload.get("groups", [])
return AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
)
except Exception as e:
logger.exception("Failed to validate Kubernetes token")
raise ValueError("Invalid or expired token") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
self._client.close()
self._client = None
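The padding expression above works because base64 input lengths must be multiples of four, while JWT segments strip trailing `=`. A quick worked check with a made-up payload:

```python
import base64
import json

part = "eyJzdWIiOiJib2IifQ"  # payload segment for {"sub": "bob"}, padding stripped
padded = part + "=" * (-len(part) % 4)  # len 18 -> two "=" appended
print(json.loads(base64.b64decode(padded)))  # {'sub': 'bob'}
```

Note that real JWTs use the URL-safe base64 alphabet, so payloads containing `-` or `_` would need `base64.urlsafe_b64decode` instead.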
class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint."""
def __init__(self, config: Dict[str, str]):
self.endpoint = config["endpoint"]
self._client = None
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
"""Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")
if scope is None:
scope = {}
headers = dict(scope.get("headers", []))
path = scope.get("path", "")
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
# Remove sensitive headers
if "authorization" in request_headers:
del request_headers["authorization"]
query_string = scope.get("query_string", b"").decode()
params = parse_qs(query_string)
# Build the auth request model
auth_request = AuthRequest(
api_key=token,
request=AuthRequestContext(
path=path,
headers=request_headers,
params=params,
),
)
# Validate with authentication endpoint
try:
async with httpx.AsyncClient() as client:
response = await client.post(
self.endpoint,
json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Authentication failed with status code: {response.status_code}")
raise ValueError(f"Authentication failed: {response.status_code}")
# Parse and validate the auth response
try:
response_data = response.json()
auth_response = AuthResponse(**response_data)
# Store attributes in request scope for access control
if auth_response.access_attributes:
return auth_response.access_attributes
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [token],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
return auth_response.access_attributes
except Exception as e:
logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e
except httpx.TimeoutException:
logger.exception("Authentication request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during authentication")
raise ValueError("Authentication service error") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
if provider_type == "kubernetes":
return KubernetesAuthProvider(config.config)
elif provider_type == "custom":
return CustomAuthProvider(config.config)
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")


@ -419,9 +419,9 @@ def main(args: Optional[argparse.Namespace] = None):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
if config.server.auth and config.server.auth.endpoint:
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
try:
impls = asyncio.run(construct_stack(config))


@ -94,12 +94,16 @@ def tool_chat_page():
st.subheader("Agent Configurations")
st.subheader("Agent Type")
agent_type = st.radio(
"Select Agent Type",
[AgentType.REGULAR, AgentType.REACT],
format_func=lambda x: x.value,
label="Select Agent Type",
options=["Regular", "ReAct"],
on_change=reset_agent,
)
if agent_type == "ReAct":
agent_type = AgentType.REACT
else:
agent_type = AgentType.REGULAR
max_tokens = st.slider(
"Max Tokens",
min_value=0,

View file

@ -792,6 +792,13 @@ def llama3_3_instruct_models() -> List[Model]:
@lru_cache
def safety_models() -> List[Model]:
return [
Model(
core_model_id=CoreModelId.llama_guard_4_12b,
description="Llama Guard v4 12b system safety model",
huggingface_repo="meta-llama/Llama-Guard-4-12B",
arch_args={},
pth_file_count=1,
),
Model(
core_model_id=CoreModelId.llama_guard_3_11b_vision,
description="Llama Guard v3 11b vision system safety model",


@ -81,6 +81,7 @@ class CoreModelId(Enum):
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
llama_guard_4_12b = "Llama-Guard-4-12B"
def is_multimodal(model_id) -> bool:
@ -148,6 +149,7 @@ def model_family(model_id) -> ModelFamily:
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
CoreModelId.llama_guard_4_12b,
]:
return ModelFamily.safety
else:
@ -225,5 +227,7 @@ class Model(BaseModel):
CoreModelId.llama_guard_3_1b,
]:
return 131072
elif self.core_model_id == CoreModelId.llama_guard_4_12b:
return 8192
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")


@ -23,6 +23,9 @@ from llama_stack.apis.agents import (
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
OpenAIResponseInputMessage,
OpenAIResponseInputTool,
OpenAIResponseObject,
Session,
Turn,
)
@ -40,6 +43,7 @@ from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_imp
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@ -63,9 +67,16 @@ class MetaReferenceAgentsImpl(Agents):
self.tool_groups_api = tool_groups_api
self.in_memory_store = InmemoryKVStoreImpl()
self.openai_responses_impl = None
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
self.openai_responses_impl = OpenAIResponsesImpl(
self.persistence_store,
inference_api=self.inference_api,
tool_groups_api=self.tool_groups_api,
tool_runtime_api=self.tool_runtime_api,
)
# check if "bwrap" is available
if not shutil.which("bwrap"):
@ -244,3 +255,23 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
) -> ListAgentSessionsResponse:
pass
# OpenAI responses
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.get_openai_response(id)
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
tools: Optional[List[OpenAIResponseInputTool]] = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
input, model, previous_response_id, store, stream, tools
)


@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from typing import AsyncIterator, List, Optional, Union, cast
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessage,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseOutput,
OpenAIResponseOutputMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.inference.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIImageURL,
OpenAIMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack.providers.utils.kvstore import KVStore
logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
messages: List[OpenAIMessageParam] = []
for output_message in previous_response.output:
if isinstance(output_message, OpenAIResponseOutputMessage):
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
return messages
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
output_messages = []
for choice in choices:
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
# TODO: handle image content
output_messages.append(
OpenAIResponseOutputMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
)
)
return output_messages
class OpenAIResponsesImpl:
def __init__(
self,
persistence_store: KVStore,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
):
self.persistence_store = persistence_store
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
response_json = await self.persistence_store.get(key=key)
if response_json is None:
raise ValueError(f"OpenAI response with id '{id}' not found")
return OpenAIResponseObject.model_validate_json(response_json)
async def create_openai_response(
self,
input: Union[str, List[OpenAIResponseInputMessage]],
model: str,
previous_response_id: Optional[str] = None,
store: Optional[bool] = True,
stream: Optional[bool] = False,
tools: Optional[List[OpenAIResponseInputTool]] = None,
):
stream = False if stream is None else stream
messages: List[OpenAIMessageParam] = []
if previous_response_id:
previous_response = await self.get_openai_response(previous_response_id)
messages.extend(await _previous_response_to_messages(previous_response))
# TODO: refactor this user_content parsing out into a separate method
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
if isinstance(input, list):
user_content = []
for user_input in input:
if isinstance(user_input.content, list):
for user_input_content in user_input.content:
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
if user_input_content.image_url:
image_url = OpenAIImageURL(
url=user_input_content.image_url, detail=user_input_content.detail
)
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
else:
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
else:
user_content = input
messages.append(OpenAIUserMessageParam(content=user_content))
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_response = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
)
if stream:
# TODO: refactor this into a separate method that handles streaming
chat_response_id = ""
chat_response_content = []
# TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
async for chunk in chat_response:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# TODO: this only works for text content
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
else:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: List[OpenAIResponseOutput] = []
if chat_response.choices[0].message.tool_calls:
output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages)
)
else:
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
response = OpenAIResponseObject(
created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}",
model=model,
object="response",
status="completed",
output=output_messages,
)
if store:
# Store in kvstore
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
await self.persistence_store.set(
key=key,
value=response.model_dump_json(),
)
if stream:
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
# TODO: response created should actually get emitted much earlier in the process
yield OpenAIResponseObjectStreamResponseCreated(response=response)
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
return async_response()
return response
async def _convert_response_tools_to_chat_tools(
self, tools: List[OpenAIResponseInputTool]
) -> List[ChatCompletionToolParam]:
chat_tools: List[ChatCompletionToolParam] = []
for input_tool in tools:
# TODO: Handle other tool types
if input_tool.type == "web_search":
tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
chat_tool = convert_tooldef_to_openai_tool(tool_def)
chat_tools.append(chat_tool)
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools
async def _execute_tool_and_return_final_output(
self, model_id: str, stream: bool, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
) -> List[OpenAIResponseOutput]:
output_messages: List[OpenAIResponseOutput] = []
choice = chat_response.choices[0]
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
if not choice.message.tool_calls:
return output_messages
# Add the assistant message with tool_calls response to the messages list
messages.append(choice.message)
for tool_call in choice.message.tool_calls:
tool_call_id = tool_call.id
function = tool_call.function
# If for some reason the tool call doesn't have a function or id, we can't execute it
if not function or not tool_call_id:
continue
# TODO: telemetry spans for tool calls
result = await self._execute_tool_call(function)
# Handle tool call failure
if not result:
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="failed",
)
)
continue
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
),
)
result_content = ""
# TODO: handle other result content types and lists
if isinstance(result.content, str):
result_content = result.content
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=model_id,
messages=messages,
stream=stream,
)
# type cast to appease mypy
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
# TODO: Wire in annotations with URLs, titles, etc to these output messages
output_messages.extend(tool_final_outputs)
return output_messages
async def _execute_tool_call(
self,
function: OpenAIChatCompletionToolCallFunction,
) -> Optional[ToolInvocationResult]:
if not function.name:
return None
function_args = json.loads(function.arguments) if function.arguments else {}
logger.info(f"executing tool call: {function.name} with args: {function_args}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=function_args,
)
logger.debug(f"tool call {function.name} completed with result: {result}")
return result
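The streaming path above currently emits exactly two events, `response.created` followed by `response.completed` (the TODO notes that `created` should fire earlier in the process). A consumer sketch, assuming `agents` implements the protocol from earlier in this diff:

```python
async def demo_stream(agents: Agents) -> None:
    stream = await agents.create_openai_response(
        input="hi",
        model="meta-llama/Llama-3.3-70B-Instruct",  # assumed registered model
        stream=True,
    )
    async for event in stream:
        if event.type == "response.created":
            print("created:", event.response.id)
        elif event.type == "response.completed":
            print("completed with", len(event.response.output), "outputs")
```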


@ -17,10 +17,8 @@ from llama_stack.apis.common.type_system import (
DialogType,
StringType,
)
from llama_stack.apis.datasets import Datasets
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
validate_dataset_schema,
)
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
@ -36,21 +34,3 @@ EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
}
],
}
async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def:
raise ValueError(f"Dataset {dataset_id} does not exist.")
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.")
validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])


@ -48,9 +48,6 @@ from llama_stack.apis.post_training import (
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
@ -348,11 +345,9 @@ class LoraFinetuningSingleDevice:
all_rows = await fetch_rows(dataset_id)
rows = all_rows.data
await validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
dataset_type=self._data_format.value,
)
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,


@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
)
service_name: str = Field(
# service name is always the same, use zero-width space to avoid clutter
default="\u200b",
default="",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:}",
"service_name": "${env.OTEL_SERVICE_NAME:}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
"sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
}


@ -227,6 +227,16 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="llama-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(


@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import LlamaCompatConfig
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
# import dynamically so the import is used only when it is needed
from .llama import LlamaCompatInferenceAdapter
adapter = LlamaCompatInferenceAdapter(config)
return adapter


@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class LlamaProviderDataValidator(BaseModel):
llama_api_key: Optional[str] = Field(
default=None,
description="API key for api.llama models",
)
@json_schema_type
class LlamaCompatConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="The Llama API key",
)
openai_compat_api_base: str = Field(
default="https://api.llama.com/compat/v1/",
description="The URL for the Llama API server",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
"api_key": api_key,
}


@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from .models import MODEL_ENTRIES
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig):
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="llama_api_key",
openai_compat_api_base=config.openai_compat_api_base,
)
self.config = config
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()


@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"Llama-3.3-70B-Instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Scout-17B-16E-Instruct-FP8",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
build_hf_repo_model_entry(
"Llama-4-Maverick-17B-128E-Instruct-FP8",
CoreModelId.llama4_maverick_17b_128e_instruct.value,
),
]


@ -433,6 +433,12 @@ class OllamaInferenceAdapter(
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self._get_model(model)
# ollama still makes tool calls even when tool_choice is "none"
# so we need to remove the tools in that case
if tool_choice == "none" and tools is not None:
tools = None
params = {
k: v
for k, v in {


@ -90,6 +90,9 @@ class LiteLLMOpenAIMixin(
raise ValueError(f"Unsupported model: {model.provider_resource_id}")
return model
def get_litellm_model_name(self, model_id: str) -> str:
return "openai/" + model_id if self.is_openai_compat else model_id
async def completion(
self,
model_id: str,
@ -130,8 +133,7 @@ class LiteLLMOpenAIMixin(
)
params = await self._get_params(request)
if self.is_openai_compat:
params["model"] = "openai/" + params["model"]
params["model"] = self.get_litellm_model_name(params["model"])
logger.debug(f"params to litellm (openai compat): {params}")
# unfortunately, we need to use synchronous litellm.completion here because litellm
@ -220,21 +222,23 @@ class LiteLLMOpenAIMixin(
else request.tool_config.tool_choice
)
return {
"model": request.model,
"api_key": self.get_api_key(),
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
return {
"model": request.model,
"api_key": api_key,
"api_base": self.api_base,
**input_dict,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
return api_key
async def embeddings(
self,
@ -247,7 +251,7 @@ class LiteLLMOpenAIMixin(
model = await self.model_store.get_model(model_id)
response = litellm.embedding(
model=model.provider_resource_id,
model=self.get_litellm_model_name(model.provider_resource_id),
input=[interleaved_content_as_str(content) for content in contents],
)
@ -278,7 +282,7 @@ class LiteLLMOpenAIMixin(
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
model=self.get_litellm_model_name(model_obj.provider_resource_id),
prompt=prompt,
best_of=best_of,
echo=echo,
@ -297,6 +301,8 @@ class LiteLLMOpenAIMixin(
user=user,
guided_choice=guided_choice,
prompt_logprobs=prompt_logprobs,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.atext_completion(**params)
@ -328,7 +334,7 @@ class LiteLLMOpenAIMixin(
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
model=self.get_litellm_model_name(model_obj.provider_resource_id),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
@ -351,6 +357,8 @@ class LiteLLMOpenAIMixin(
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.acompletion(**params)


@ -638,10 +638,13 @@ async def convert_message_to_openai_dict_new(
)
for tool in message.tool_calls
]
params = {}
if tool_calls:
params["tool_calls"] = tool_calls
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
tool_calls=tool_calls or None,
**params,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(


@ -478,6 +478,8 @@ class JsonSchemaGenerator:
}
return ret
elif origin_type is Literal:
if len(typing.get_args(typ)) != 1:
raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
(literal_value,) = typing.get_args(typ) # unpack value of literal type
schema = self.type_to_schema(type(literal_value))
schema["const"] = literal_value


@ -39,9 +39,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -79,9 +79,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/trace_store.db
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search


@ -42,9 +42,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -41,9 +41,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -344,6 +344,45 @@
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"llama_api": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",


@ -71,9 +71,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -50,9 +50,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -50,9 +50,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference


@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -50,9 +50,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .llama_api import get_distribution_template # noqa: F401

View file

@@ -0,0 +1,33 @@
version: '2'
distribution_spec:
  description: Distribution for running e2e tests in CI
  providers:
    inference:
    - remote::llama-openai-compat
    - inline::sentence-transformers
    vector_io:
    - inline::sqlite-vec
    - remote::chromadb
    - remote::pgvector
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: conda
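
The build spec above only names provider types per API; the build step resolves them into the concrete pip dependency list registered under `llama_api` in the JSON hunk earlier in this commit. A rough, hypothetical illustration of reading the spec (assumes PyYAML is installed and the file is saved locally as build.yaml):

# Hypothetical helper: list which provider types this build spec requests.
import yaml  # assumes PyYAML is installed

with open("build.yaml") as f:
    spec = yaml.safe_load(f)

for api, provider_types in spec["distribution_spec"]["providers"].items():
    print(f"{api}: {', '.join(provider_types)}")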

View file

@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Tuple

from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
    ModelInput,
    Provider,
    ShieldInput,
    ToolGroupInput,
)
from llama_stack.providers.inline.inference.sentence_transformers import (
    SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
    SQLiteVectorIOConfig,
)
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
    LlamaCompatConfig,
)
from llama_stack.providers.remote.inference.llama_openai_compat.models import (
    MODEL_ENTRIES as LLAMA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
    PGVectorVectorIOConfig,
)
from llama_stack.templates.template import (
    DistributionTemplate,
    RunConfigSettings,
    get_model_registry,
)


def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
    # in this template, we allow each API key to be optional
    providers = [
        (
            "llama-openai-compat",
            LLAMA_MODEL_ENTRIES,
            LlamaCompatConfig.sample_run_config(api_key="${env.LLAMA_API_KEY:}"),
        ),
    ]
    inference_providers = []
    available_models = {}
    for provider_id, model_entries, config in providers:
        inference_providers.append(
            Provider(
                provider_id=provider_id,
                provider_type=f"remote::{provider_id}",
                config=config,
            )
        )
        available_models[provider_id] = model_entries
    return inference_providers, available_models


def get_distribution_template() -> DistributionTemplate:
    inference_providers, available_models = get_inference_providers()
    providers = {
        "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
        "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "tool_runtime": [
            "remote::brave-search",
            "remote::tavily-search",
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
        ],
    }
    name = "llama_api"

    vector_io_providers = [
        Provider(
            provider_id="sqlite-vec",
            provider_type="inline::sqlite-vec",
            config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
        ),
        Provider(
            provider_id="${env.ENABLE_CHROMADB+chromadb}",
            provider_type="remote::chromadb",
            config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
        ),
        Provider(
            provider_id="${env.ENABLE_PGVECTOR+pgvector}",
            provider_type="remote::pgvector",
            config=PGVectorVectorIOConfig.sample_run_config(
                db="${env.PGVECTOR_DB:}",
                user="${env.PGVECTOR_USER:}",
                password="${env.PGVECTOR_PASSWORD:}",
            ),
        ),
    ]
    embedding_provider = Provider(
        provider_id="sentence-transformers",
        provider_type="inline::sentence-transformers",
        config=SentenceTransformersInferenceConfig.sample_run_config(),
    )
    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::rag",
            provider_id="rag-runtime",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
    ]
    embedding_model = ModelInput(
        model_id="all-MiniLM-L6-v2",
        provider_id=embedding_provider.provider_id,
        model_type=ModelType.embedding,
        metadata={
            "embedding_dimension": 384,
        },
    )
    default_models = get_model_registry(available_models)
    return DistributionTemplate(
        name=name,
        distro_type="self_hosted",
        description="Distribution for running e2e tests in CI",
        container_image=None,
        template_path=None,
        providers=providers,
        available_models_by_provider=available_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": inference_providers + [embedding_provider],
                    "vector_io": vector_io_providers,
                },
                default_models=default_models + [embedding_model],
                default_tool_groups=default_tool_groups,
                default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
            ),
        },
        run_config_env_vars={
            "LLAMA_STACK_PORT": (
                "8321",
                "Port for the Llama Stack distribution server",
            ),
        },
    )
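
For orientation, a minimal sketch of how a template module like this is typically consumed, assuming `DistributionTemplate` exposes its constructor arguments as attributes (the real template tooling renders these settings out to the run.yaml below):

# Minimal sketch; attribute names mirror the constructor arguments above.
from llama_stack.templates.llama_api import get_distribution_template

template = get_distribution_template()
print(template.name)               # llama_api
print(list(template.run_configs))  # ['run.yaml']
print(template.run_config_env_vars["LLAMA_STACK_PORT"])
# -> ('8321', 'Port for the Llama Stack distribution server')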

View file

@@ -0,0 +1,167 @@
version: '2'
image_name: llama_api
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: llama-openai-compat
    provider_type: remote::llama-openai-compat
    config:
      openai_compat_api_base: https://api.llama.com/compat/v1/
      api_key: ${env.LLAMA_API_KEY:}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: sqlite-vec
    provider_type: inline::sqlite-vec
    config:
      db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/sqlite_vec.db
  - provider_id: ${env.ENABLE_CHROMADB+chromadb}
    provider_type: remote::chromadb
    config:
      url: ${env.CHROMADB_URL:}
  - provider_id: ${env.ENABLE_PGVECTOR+pgvector}
    provider_type: remote::pgvector
    config:
      host: ${env.PGVECTOR_HOST:localhost}
      port: ${env.PGVECTOR_PORT:5432}
      db: ${env.PGVECTOR_DB:}
      user: ${env.PGVECTOR_USER:}
      password: ${env.PGVECTOR_PASSWORD:}
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/trace_store.db
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/meta_reference_eval.db
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db
models:
- metadata: {}
  model_id: Llama-3.3-70B-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-3.3-70B-Instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-3.3-70B-Instruct
  model_type: llm
- metadata: {}
  model_id: Llama-4-Scout-17B-16E-Instruct-FP8
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Scout-17B-16E-Instruct-FP8
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Scout-17B-16E-Instruct-FP8
  model_type: llm
- metadata: {}
  model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
  provider_id: llama-openai-compat
  provider_model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: sentence-transformers
  model_type: embedding
shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321
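
Once a server is launched from this run config it listens on port 8321; a quick smoke test with the Python client (assumes the llama-stack-client package is installed and a server is running locally):

# Hypothetical smoke test against a server started from the run.yaml above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Expect the Llama API chat models plus the all-MiniLM-L6-v2 embedding model.
for model in client.models.list():
    print(model.identifier, model.model_type)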

View file

@@ -60,9 +60,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -50,9 +50,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -50,9 +50,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/trace_store.db
eval:
- provider_id: nvidia
provider_type: remote::nvidia

View file

@@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/trace_store.db
eval:
- provider_id: nvidia
provider_type: remote::nvidia

View file

@@ -43,9 +43,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -41,9 +41,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -68,9 +68,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -50,9 +50,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -88,9 +88,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/trace_store.db
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search

View file

@@ -81,9 +81,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/trace_store.db
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search

View file

@@ -51,9 +51,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/sambanova/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/trace_store.db
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search

View file

@@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -44,9 +44,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -50,9 +50,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -45,9 +45,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -78,9 +78,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/verification/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -49,9 +49,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/vllm-gpu/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference

View file

@@ -43,9 +43,9 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
service_name: ${env.OTEL_SERVICE_NAME:}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db}
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/trace_store.db
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference