mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-21 03:59:42 +00:00
Merge branch 'main' into nvidia-e2e-notebook
This commit is contained in:
commit
012dd6891f
96 changed files with 4675 additions and 426 deletions
|
@ -38,6 +38,13 @@ from llama_stack.apis.safety import SafetyViolation
|
|||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
)
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
@ -593,3 +600,39 @@ class Agents(Protocol):
|
|||
:returns: A ListAgentSessionsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
# We situate the OpenAI Responses API in the Agents API just like we did things
|
||||
# for Inference. The Responses API, in its intent, serves the same purpose as
|
||||
# the Agents API above -- it is essentially a lightweight "agentic loop" with
|
||||
# integrated tool calling.
|
||||
#
|
||||
# Both of these APIs are inherently stateful.
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{id}", method="GET")
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
"""Retrieve an OpenAI response by its ID.
|
||||
|
||||
:param id: The ID of the OpenAI response to retrieve.
|
||||
:returns: An OpenAIResponseObject.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="POST")
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
) -> Union[OpenAIResponseObject, AsyncIterator[OpenAIResponseObjectStream]]:
|
||||
"""Create a new OpenAI response.
|
||||
|
||||
:param input: Input message(s) to create the response.
|
||||
:param model: The underlying LLM used for completions.
|
||||
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
|
||||
"""
|
||||
|
|
140
llama_stack/apis/agents/openai_responses.py
Normal file
140
llama_stack/apis/agents/openai_responses.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseError(BaseModel):
|
||||
code: str
|
||||
message: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageContentOutputText(BaseModel):
|
||||
text: str
|
||||
type: Literal["output_text"] = "output_text"
|
||||
|
||||
|
||||
OpenAIResponseOutputMessageContent = Annotated[
|
||||
Union[OpenAIResponseOutputMessageContentOutputText,],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessage(BaseModel):
|
||||
id: str
|
||||
content: List[OpenAIResponseOutputMessageContent]
|
||||
role: Literal["assistant"] = "assistant"
|
||||
status: str
|
||||
type: Literal["message"] = "message"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
|
||||
id: str
|
||||
status: str
|
||||
type: Literal["web_search_call"] = "web_search_call"
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
Union[
|
||||
OpenAIResponseOutputMessage,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObject(BaseModel):
|
||||
created_at: int
|
||||
error: Optional[OpenAIResponseError] = None
|
||||
id: str
|
||||
model: str
|
||||
object: Literal["response"] = "response"
|
||||
output: List[OpenAIResponseOutput]
|
||||
parallel_tool_calls: bool = False
|
||||
previous_response_id: Optional[str] = None
|
||||
status: str
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
truncation: Optional[str] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCreated(BaseModel):
|
||||
response: OpenAIResponseObject
|
||||
type: Literal["response.created"] = "response.created"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
|
||||
response: OpenAIResponseObject
|
||||
type: Literal["response.completed"] = "response.completed"
|
||||
|
||||
|
||||
OpenAIResponseObjectStream = Annotated[
|
||||
Union[
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentText(BaseModel):
|
||||
text: str
|
||||
type: Literal["input_text"] = "input_text"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessageContentImage(BaseModel):
|
||||
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
|
||||
type: Literal["input_image"] = "input_image"
|
||||
# TODO: handle file_id
|
||||
image_url: Optional[str] = None
|
||||
|
||||
|
||||
# TODO: handle file content types
|
||||
OpenAIResponseInputMessageContent = Annotated[
|
||||
Union[OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentImage],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputMessage(BaseModel):
|
||||
content: Union[str, List[OpenAIResponseInputMessageContent]]
|
||||
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||
type: Optional[Literal["message"]] = "message"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseInputToolWebSearch(BaseModel):
|
||||
type: Literal["web_search"] | Literal["web_search_preview_2025_03_11"] = "web_search"
|
||||
# TODO: actually use search_context_size somewhere...
|
||||
search_context_size: Optional[str] = Field(default="medium", pattern="^low|medium|high$")
|
||||
# TODO: add user_location
|
||||
|
||||
|
||||
OpenAIResponseInputTool = Annotated[
|
||||
Union[OpenAIResponseInputToolWebSearch,],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
|
|
@ -460,15 +460,17 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
|||
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
|
||||
|
||||
from .model.safety_models import (
|
||||
prompt_guard_download_info,
|
||||
prompt_guard_model_sku,
|
||||
prompt_guard_download_info_map,
|
||||
prompt_guard_model_sku_map,
|
||||
)
|
||||
|
||||
prompt_guard = prompt_guard_model_sku()
|
||||
prompt_guard_model_sku_map = prompt_guard_model_sku_map()
|
||||
prompt_guard_download_info_map = prompt_guard_download_info_map()
|
||||
|
||||
for model_id in model_ids:
|
||||
if model_id == prompt_guard.model_id:
|
||||
model = prompt_guard
|
||||
info = prompt_guard_download_info()
|
||||
if model_id in prompt_guard_model_sku_map.keys():
|
||||
model = prompt_guard_model_sku_map[model_id]
|
||||
info = prompt_guard_download_info_map[model_id]
|
||||
else:
|
||||
model = resolve_model(model_id)
|
||||
if model is None:
|
||||
|
|
|
@ -36,11 +36,11 @@ class ModelDescribe(Subcommand):
|
|||
)
|
||||
|
||||
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku
|
||||
from .safety_models import prompt_guard_model_sku_map
|
||||
|
||||
prompt_guard = prompt_guard_model_sku()
|
||||
if args.model_id == prompt_guard.model_id:
|
||||
model = prompt_guard
|
||||
prompt_guard_model_map = prompt_guard_model_sku_map()
|
||||
if args.model_id in prompt_guard_model_map.keys():
|
||||
model = prompt_guard_model_map[args.model_id]
|
||||
else:
|
||||
model = resolve_model(args.model_id)
|
||||
|
||||
|
|
|
@ -84,7 +84,7 @@ class ModelList(Subcommand):
|
|||
)
|
||||
|
||||
def _run_model_list_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku
|
||||
from .safety_models import prompt_guard_model_skus
|
||||
|
||||
if args.downloaded:
|
||||
return _run_model_list_downloaded_cmd()
|
||||
|
@ -96,7 +96,7 @@ class ModelList(Subcommand):
|
|||
]
|
||||
|
||||
rows = []
|
||||
for model in all_registered_models() + [prompt_guard_model_sku()]:
|
||||
for model in all_registered_models() + prompt_guard_model_skus():
|
||||
if not args.show_all and not model.is_featured:
|
||||
continue
|
||||
|
||||
|
|
|
@ -42,11 +42,12 @@ class ModelRemove(Subcommand):
|
|||
)
|
||||
|
||||
def _run_model_remove_cmd(self, args: argparse.Namespace) -> None:
|
||||
from .safety_models import prompt_guard_model_sku
|
||||
from .safety_models import prompt_guard_model_sku_map
|
||||
|
||||
prompt_guard = prompt_guard_model_sku()
|
||||
if args.model == prompt_guard.model_id:
|
||||
model = prompt_guard
|
||||
prompt_guard_model_map = prompt_guard_model_sku_map()
|
||||
|
||||
if args.model in prompt_guard_model_map.keys():
|
||||
model = prompt_guard_model_map[args.model]
|
||||
else:
|
||||
model = resolve_model(args.model)
|
||||
|
||||
|
|
|
@ -15,11 +15,11 @@ from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
|
|||
class PromptGuardModel(BaseModel):
|
||||
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""
|
||||
|
||||
model_id: str = "Prompt-Guard-86M"
|
||||
model_id: str
|
||||
huggingface_repo: str
|
||||
description: str = "Prompt Guard. NOTE: this model will not be provided via `llama` CLI soon."
|
||||
is_featured: bool = False
|
||||
huggingface_repo: str = "meta-llama/Prompt-Guard-86M"
|
||||
max_seq_length: int = 2048
|
||||
max_seq_length: int = 512
|
||||
is_instruct_model: bool = False
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
arch_args: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
@ -30,18 +30,35 @@ class PromptGuardModel(BaseModel):
|
|||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
def prompt_guard_model_sku():
|
||||
return PromptGuardModel()
|
||||
def prompt_guard_model_skus():
|
||||
return [
|
||||
PromptGuardModel(model_id="Prompt-Guard-86M", huggingface_repo="meta-llama/Prompt-Guard-86M"),
|
||||
PromptGuardModel(
|
||||
model_id="Llama-Prompt-Guard-2-86M",
|
||||
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-86M",
|
||||
),
|
||||
PromptGuardModel(
|
||||
model_id="Llama-Prompt-Guard-2-22M",
|
||||
huggingface_repo="meta-llama/Llama-Prompt-Guard-2-22M",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def prompt_guard_download_info():
|
||||
return LlamaDownloadInfo(
|
||||
folder="Prompt-Guard",
|
||||
files=[
|
||||
"model.safetensors",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
],
|
||||
pth_size=1,
|
||||
)
|
||||
def prompt_guard_model_sku_map() -> Dict[str, Any]:
|
||||
return {model.model_id: model for model in prompt_guard_model_skus()}
|
||||
|
||||
|
||||
def prompt_guard_download_info_map() -> Dict[str, LlamaDownloadInfo]:
|
||||
return {
|
||||
model.model_id: LlamaDownloadInfo(
|
||||
folder="Prompt-Guard" if model.model_id == "Prompt-Guard-86M" else model.model_id,
|
||||
files=[
|
||||
"model.safetensors",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
],
|
||||
pth_size=1,
|
||||
)
|
||||
for model in prompt_guard_model_skus()
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
@ -235,10 +236,21 @@ class LoggingConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class AuthProviderType(str, Enum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class AuthenticationConfig(BaseModel):
|
||||
endpoint: str = Field(
|
||||
provider_type: AuthProviderType = Field(
|
||||
...,
|
||||
description="Endpoint URL to validate authentication tokens",
|
||||
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
|
||||
)
|
||||
config: Dict[str, str] = Field(
|
||||
...,
|
||||
description="Provider-specific configuration",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -5,74 +5,29 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
|
||||
"""Middleware that authenticates requests using an external auth endpoint.
|
||||
"""Middleware that authenticates requests using configured authentication provider.
|
||||
|
||||
This middleware:
|
||||
1. Extracts the Bearer token from the Authorization header
|
||||
2. Sends it to the configured auth endpoint along with request details
|
||||
3. Validates the response and extracts user attributes
|
||||
2. Uses the configured auth provider to validate the token
|
||||
3. Extracts user attributes from the provider's response
|
||||
4. Makes these attributes available to the route handlers for access control
|
||||
|
||||
Authentication Request Format:
|
||||
The middleware supports multiple authentication providers through the AuthProvider interface:
|
||||
- Kubernetes: Validates tokens against the Kubernetes API server
|
||||
- Custom: Validates tokens against a custom endpoint
|
||||
|
||||
Authentication Request Format for Custom Auth Provider:
|
||||
```json
|
||||
{
|
||||
"api_key": "the-api-key-extracted-from-auth-header",
|
||||
|
@ -105,21 +60,26 @@ class AuthenticationMiddleware:
|
|||
}
|
||||
```
|
||||
|
||||
Token Validation:
|
||||
Each provider implements its own token validation logic:
|
||||
- Kubernetes: Uses TokenReview API to validate service account tokens
|
||||
- Custom: Sends token to custom endpoint for validation
|
||||
|
||||
Attribute-Based Access Control:
|
||||
The attributes returned by the auth endpoint are used to determine which
|
||||
The attributes returned by the auth provider are used to determine which
|
||||
resources the user can access. Resources can specify required attributes
|
||||
using the access_attributes field. For a user to access a resource:
|
||||
|
||||
1. All attribute categories specified in the resource must be present in the user's attributes
|
||||
2. For each category, the user must have at least one matching value
|
||||
|
||||
If the auth endpoint doesn't return any attributes, the user will only be able to
|
||||
If the auth provider doesn't return any attributes, the user will only be able to
|
||||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_endpoint):
|
||||
def __init__(self, app, auth_config: AuthProviderConfig):
|
||||
self.app = app
|
||||
self.auth_endpoint = auth_endpoint
|
||||
self.auth_provider = create_auth_provider(auth_config)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
|
@ -129,66 +89,34 @@ class AuthenticationMiddleware:
|
|||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
return await self._send_auth_error(send, "Missing or invalid Authorization header")
|
||||
|
||||
api_key = auth_header.split("Bearer ", 1)[1]
|
||||
token = auth_header.split("Bearer ", 1)[1]
|
||||
|
||||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=api_key,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
# Validate token and get access attributes
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.auth_endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed: {response.status_code}")
|
||||
return await self._send_auth_error(send, "Authentication failed")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [api_key],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
return await self._send_auth_error(send, "Invalid authentication response format")
|
||||
access_attributes = await self.auth_provider.validate_token(token, scope)
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
return await self._send_auth_error(send, "Authentication service timeout")
|
||||
except ValueError as e:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, str(e))
|
||||
except Exception:
|
||||
logger.exception("Error during authentication")
|
||||
return await self._send_auth_error(send, "Authentication service error")
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if access_attributes:
|
||||
user_attributes = access_attributes.model_dump(exclude_none=True)
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to token by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
}
|
||||
|
||||
# Store attributes in request scope
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes")
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send, message):
|
||||
|
|
262
llama_stack/distribution/server/auth_providers.py
Normal file
262
llama_stack/distribution/server/auth_providers.py
Normal file
|
@ -0,0 +1,262 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
|
||||
"""The format of the authentication response from the auth endpoint."""
|
||||
|
||||
access_attributes: Optional[AccessAttributes] = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Structured user attributes for attribute-based access control.
|
||||
|
||||
These attributes determine which resources the user can access.
|
||||
The model provides standard categories like "roles", "teams", "projects", and "namespaces".
|
||||
Each attribute category contains a list of values that the user has for that category.
|
||||
During access control checks, these values are compared against resource requirements.
|
||||
|
||||
Example with standard categories:
|
||||
```json
|
||||
{
|
||||
"roles": ["admin", "data-scientist"],
|
||||
"teams": ["ml-team"],
|
||||
"projects": ["llama-3"],
|
||||
"namespaces": ["research"]
|
||||
}
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
message: Optional[str] = Field(
|
||||
default=None, description="Optional message providing additional context about the authentication result."
|
||||
)
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
|
||||
path: str = Field(description="The path of the request being authenticated")
|
||||
|
||||
headers: Dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")
|
||||
|
||||
params: Dict[str, List[str]] = Field(
|
||||
description="Query parameters from the original request, parsed as dictionary of lists"
|
||||
)
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
|
||||
api_key: str = Field(description="The API key extracted from the Authorization header")
|
||||
|
||||
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthProviderType(str, Enum):
|
||||
"""Supported authentication provider types."""
|
||||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
|
||||
|
||||
class AuthProviderConfig(BaseModel):
|
||||
"""Base configuration for authentication providers."""
|
||||
|
||||
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
|
||||
config: Dict[str, str] = Field(..., description="Provider-specific configuration")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
"""Validate a token and return access attributes."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Clean up any resources."""
|
||||
pass
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
self.api_server_url = config["api_server_url"]
|
||||
self.ca_cert_path = config.get("ca_cert_path")
|
||||
self._client = None
|
||||
|
||||
async def _get_client(self):
|
||||
"""Get or create a Kubernetes client."""
|
||||
if self._client is None:
|
||||
# kubernetes-client has not async support, see:
|
||||
# https://github.com/kubernetes-client/python/issues/323
|
||||
from kubernetes import client
|
||||
from kubernetes.client import ApiClient
|
||||
|
||||
# Configure the client
|
||||
configuration = client.Configuration()
|
||||
configuration.host = self.api_server_url
|
||||
if self.ca_cert_path:
|
||||
configuration.ssl_ca_cert = self.ca_cert_path
|
||||
configuration.verify_ssl = bool(self.ca_cert_path)
|
||||
|
||||
# Create API client
|
||||
self._client = ApiClient(configuration)
|
||||
return self._client
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
"""Validate a Kubernetes token and return access attributes."""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
|
||||
# Set the token in the client
|
||||
client.set_default_header("Authorization", f"Bearer {token}")
|
||||
|
||||
# Make a request to validate the token
|
||||
# We use the /api endpoint which requires authentication
|
||||
from kubernetes.client import CoreV1Api
|
||||
|
||||
api = CoreV1Api(client)
|
||||
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
|
||||
|
||||
# If we get here, the token is valid
|
||||
# Extract user info from the token claims
|
||||
import base64
|
||||
|
||||
# Decode the token (without verification since we've already validated it)
|
||||
token_parts = token.split(".")
|
||||
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
|
||||
|
||||
# Extract user information from the token
|
||||
username = payload.get("sub", "")
|
||||
groups = payload.get("groups", [])
|
||||
|
||||
return AccessAttributes(
|
||||
roles=[username], # Use username as a role
|
||||
teams=groups, # Use Kubernetes groups as teams
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to validate Kubernetes token")
|
||||
raise ValueError("Invalid or expired token") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
self._client = None
|
||||
|
||||
|
||||
class CustomAuthProvider(AuthProvider):
|
||||
"""Custom authentication provider that uses an external endpoint."""
|
||||
|
||||
def __init__(self, config: Dict[str, str]):
|
||||
self.endpoint = config["endpoint"]
|
||||
self._client = None
|
||||
|
||||
async def validate_token(self, token: str, scope: Optional[Dict] = None) -> Optional[AccessAttributes]:
|
||||
"""Validate a token using the custom authentication endpoint."""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Authentication endpoint not configured")
|
||||
|
||||
if scope is None:
|
||||
scope = {}
|
||||
|
||||
headers = dict(scope.get("headers", []))
|
||||
path = scope.get("path", "")
|
||||
request_headers = {k.decode(): v.decode() for k, v in headers.items()}
|
||||
|
||||
# Remove sensitive headers
|
||||
if "authorization" in request_headers:
|
||||
del request_headers["authorization"]
|
||||
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
params = parse_qs(query_string)
|
||||
|
||||
# Build the auth request model
|
||||
auth_request = AuthRequest(
|
||||
api_key=token,
|
||||
request=AuthRequestContext(
|
||||
path=path,
|
||||
headers=request_headers,
|
||||
params=params,
|
||||
),
|
||||
)
|
||||
|
||||
# Validate with authentication endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.endpoint,
|
||||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||
|
||||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
return auth_response.access_attributes
|
||||
else:
|
||||
logger.warning("No access attributes, setting namespace to api_key by default")
|
||||
user_attributes = {
|
||||
"namespaces": [token],
|
||||
}
|
||||
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
return auth_response.access_attributes
|
||||
except Exception as e:
|
||||
logger.exception("Error parsing authentication response")
|
||||
raise ValueError("Invalid authentication response format") from e
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.exception("Authentication request timed out")
|
||||
raise
|
||||
except ValueError:
|
||||
# Re-raise ValueError exceptions to preserve their message
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error during authentication")
|
||||
raise ValueError("Authentication service error") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_type = config.provider_type.lower()
|
||||
|
||||
if provider_type == "kubernetes":
|
||||
return KubernetesAuthProvider(config.config)
|
||||
elif provider_type == "custom":
|
||||
return CustomAuthProvider(config.config)
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
|
@ -419,9 +419,9 @@ def main(args: Optional[argparse.Namespace] = None):
|
|||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
if config.server.auth and config.server.auth.endpoint:
|
||||
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
|
|
|
@ -94,12 +94,16 @@ def tool_chat_page():
|
|||
st.subheader("Agent Configurations")
|
||||
st.subheader("Agent Type")
|
||||
agent_type = st.radio(
|
||||
"Select Agent Type",
|
||||
[AgentType.REGULAR, AgentType.REACT],
|
||||
format_func=lambda x: x.value,
|
||||
label="Select Agent Type",
|
||||
options=["Regular", "ReAct"],
|
||||
on_change=reset_agent,
|
||||
)
|
||||
|
||||
if agent_type == "ReAct":
|
||||
agent_type = AgentType.REACT
|
||||
else:
|
||||
agent_type = AgentType.REGULAR
|
||||
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
|
|
|
@ -792,6 +792,13 @@ def llama3_3_instruct_models() -> List[Model]:
|
|||
@lru_cache
|
||||
def safety_models() -> List[Model]:
|
||||
return [
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama_guard_4_12b,
|
||||
description="Llama Guard v4 12b system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-4-12B",
|
||||
arch_args={},
|
||||
pth_file_count=1,
|
||||
),
|
||||
Model(
|
||||
core_model_id=CoreModelId.llama_guard_3_11b_vision,
|
||||
description="Llama Guard v3 11b vision system safety model",
|
||||
|
|
|
@ -81,6 +81,7 @@ class CoreModelId(Enum):
|
|||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||
llama_guard_4_12b = "Llama-Guard-4-12B"
|
||||
|
||||
|
||||
def is_multimodal(model_id) -> bool:
|
||||
|
@ -148,6 +149,7 @@ def model_family(model_id) -> ModelFamily:
|
|||
CoreModelId.llama_guard_2_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
CoreModelId.llama_guard_4_12b,
|
||||
]:
|
||||
return ModelFamily.safety
|
||||
else:
|
||||
|
@ -225,5 +227,7 @@ class Model(BaseModel):
|
|||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return 131072
|
||||
elif self.core_model_id == CoreModelId.llama_guard_4_12b:
|
||||
return 8192
|
||||
else:
|
||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
||||
|
|
|
@ -23,6 +23,9 @@ from llama_stack.apis.agents import (
|
|||
Document,
|
||||
ListAgentSessionsResponse,
|
||||
ListAgentsResponse,
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
|
@ -40,6 +43,7 @@ from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_imp
|
|||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -63,9 +67,16 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.tool_groups_api = tool_groups_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.openai_responses_impl = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
self.openai_responses_impl = OpenAIResponsesImpl(
|
||||
self.persistence_store,
|
||||
inference_api=self.inference_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
)
|
||||
|
||||
# check if "bwrap" is available
|
||||
if not shutil.which("bwrap"):
|
||||
|
@ -244,3 +255,23 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
) -> ListAgentSessionsResponse:
|
||||
pass
|
||||
|
||||
# OpenAI responses
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.get_openai_response(id)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
input, model, previous_response_id, store, stream, tools
|
||||
)
|
||||
|
|
|
@ -0,0 +1,319 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import AsyncIterator, List, Optional, Union, cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputMessage,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseCreated,
|
||||
OpenAIResponseOutput,
|
||||
OpenAIResponseOutputMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIImageURL,
|
||||
OpenAIMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_responses")
|
||||
|
||||
OPENAI_RESPONSES_PREFIX = "openai_responses:"
|
||||
|
||||
|
||||
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> List[OpenAIMessageParam]:
|
||||
messages: List[OpenAIMessageParam] = []
|
||||
for output_message in previous_response.output:
|
||||
if isinstance(output_message, OpenAIResponseOutputMessage):
|
||||
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text))
|
||||
return messages
|
||||
|
||||
|
||||
async def _openai_choices_to_output_messages(choices: List[OpenAIChoice]) -> List[OpenAIResponseOutputMessage]:
|
||||
output_messages = []
|
||||
for choice in choices:
|
||||
output_content = ""
|
||||
if isinstance(choice.message.content, str):
|
||||
output_content = choice.message.content
|
||||
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
|
||||
output_content = choice.message.content.text
|
||||
# TODO: handle image content
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessage(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
|
||||
status="completed",
|
||||
)
|
||||
)
|
||||
return output_messages
|
||||
|
||||
|
||||
class OpenAIResponsesImpl:
|
||||
def __init__(
|
||||
self,
|
||||
persistence_store: KVStore,
|
||||
inference_api: Inference,
|
||||
tool_groups_api: ToolGroups,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
):
|
||||
self.persistence_store = persistence_store
|
||||
self.inference_api = inference_api
|
||||
self.tool_groups_api = tool_groups_api
|
||||
self.tool_runtime_api = tool_runtime_api
|
||||
|
||||
async def get_openai_response(
|
||||
self,
|
||||
id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{id}"
|
||||
response_json = await self.persistence_store.get(key=key)
|
||||
if response_json is None:
|
||||
raise ValueError(f"OpenAI response with id '{id}' not found")
|
||||
return OpenAIResponseObject.model_validate_json(response_json)
|
||||
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: Union[str, List[OpenAIResponseInputMessage]],
|
||||
model: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
store: Optional[bool] = True,
|
||||
stream: Optional[bool] = False,
|
||||
tools: Optional[List[OpenAIResponseInputTool]] = None,
|
||||
):
|
||||
stream = False if stream is None else stream
|
||||
|
||||
messages: List[OpenAIMessageParam] = []
|
||||
if previous_response_id:
|
||||
previous_response = await self.get_openai_response(previous_response_id)
|
||||
messages.extend(await _previous_response_to_messages(previous_response))
|
||||
# TODO: refactor this user_content parsing out into a separate method
|
||||
user_content: Union[str, List[OpenAIChatCompletionContentPartParam]] = ""
|
||||
if isinstance(input, list):
|
||||
user_content = []
|
||||
for user_input in input:
|
||||
if isinstance(user_input.content, list):
|
||||
for user_input_content in user_input.content:
|
||||
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
|
||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
|
||||
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
|
||||
if user_input_content.image_url:
|
||||
image_url = OpenAIImageURL(
|
||||
url=user_input_content.image_url, detail=user_input_content.detail
|
||||
)
|
||||
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
|
||||
else:
|
||||
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
|
||||
else:
|
||||
user_content = input
|
||||
messages.append(OpenAIUserMessageParam(content=user_content))
|
||||
|
||||
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
|
||||
chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=chat_tools,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if stream:
|
||||
# TODO: refactor this into a separate method that handles streaming
|
||||
chat_response_id = ""
|
||||
chat_response_content = []
|
||||
# TODO: these chunk_ fields are hacky and only take the last chunk into account
|
||||
chunk_created = 0
|
||||
chunk_model = ""
|
||||
chunk_finish_reason = ""
|
||||
async for chunk in chat_response:
|
||||
chat_response_id = chunk.id
|
||||
chunk_created = chunk.created
|
||||
chunk_model = chunk.model
|
||||
for chunk_choice in chunk.choices:
|
||||
# TODO: this only works for text content
|
||||
chat_response_content.append(chunk_choice.delta.content or "")
|
||||
if chunk_choice.finish_reason:
|
||||
chunk_finish_reason = chunk_choice.finish_reason
|
||||
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
|
||||
chat_response = OpenAIChatCompletion(
|
||||
id=chat_response_id,
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=assistant_message,
|
||||
finish_reason=chunk_finish_reason,
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=chunk_created,
|
||||
model=chunk_model,
|
||||
)
|
||||
else:
|
||||
# dump and reload to map to our pydantic types
|
||||
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
|
||||
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
if chat_response.choices[0].message.tool_calls:
|
||||
output_messages.extend(
|
||||
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages)
|
||||
)
|
||||
else:
|
||||
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices))
|
||||
response = OpenAIResponseObject(
|
||||
created_at=chat_response.created,
|
||||
id=f"resp-{uuid.uuid4()}",
|
||||
model=model,
|
||||
object="response",
|
||||
status="completed",
|
||||
output=output_messages,
|
||||
)
|
||||
|
||||
if store:
|
||||
# Store in kvstore
|
||||
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
|
||||
await self.persistence_store.set(
|
||||
key=key,
|
||||
value=response.model_dump_json(),
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
||||
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# TODO: response created should actually get emitted much earlier in the process
|
||||
yield OpenAIResponseObjectStreamResponseCreated(response=response)
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
|
||||
|
||||
return async_response()
|
||||
|
||||
return response
|
||||
|
||||
async def _convert_response_tools_to_chat_tools(
|
||||
self, tools: List[OpenAIResponseInputTool]
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
chat_tools: List[ChatCompletionToolParam] = []
|
||||
for input_tool in tools:
|
||||
# TODO: Handle other tool types
|
||||
if input_tool.type == "web_search":
|
||||
tool_name = "web_search"
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=tool_name,
|
||||
description=tool.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool.parameters
|
||||
},
|
||||
)
|
||||
chat_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
chat_tools.append(chat_tool)
|
||||
else:
|
||||
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
|
||||
return chat_tools
|
||||
|
||||
async def _execute_tool_and_return_final_output(
|
||||
self, model_id: str, stream: bool, chat_response: OpenAIChatCompletion, messages: List[OpenAIMessageParam]
|
||||
) -> List[OpenAIResponseOutput]:
|
||||
output_messages: List[OpenAIResponseOutput] = []
|
||||
choice = chat_response.choices[0]
|
||||
|
||||
# If the choice is not an assistant message, we don't need to execute any tools
|
||||
if not isinstance(choice.message, OpenAIAssistantMessageParam):
|
||||
return output_messages
|
||||
|
||||
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
|
||||
if not choice.message.tool_calls:
|
||||
return output_messages
|
||||
|
||||
# Add the assistant message with tool_calls response to the messages list
|
||||
messages.append(choice.message)
|
||||
|
||||
for tool_call in choice.message.tool_calls:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
|
||||
# If for some reason the tool call doesn't have a function or id, we can't execute it
|
||||
if not function or not tool_call_id:
|
||||
continue
|
||||
|
||||
# TODO: telemetry spans for tool calls
|
||||
result = await self._execute_tool_call(function)
|
||||
|
||||
# Handle tool call failure
|
||||
if not result:
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="failed",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
output_messages.append(
|
||||
OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
id=tool_call_id,
|
||||
status="completed",
|
||||
),
|
||||
)
|
||||
|
||||
result_content = ""
|
||||
# TODO: handle other result content types and lists
|
||||
if isinstance(result.content, str):
|
||||
result_content = result.content
|
||||
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
|
||||
|
||||
tool_results_chat_response = await self.inference_api.openai_chat_completion(
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
)
|
||||
# type cast to appease mypy
|
||||
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
|
||||
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices)
|
||||
# TODO: Wire in annotations with URLs, titles, etc to these output messages
|
||||
output_messages.extend(tool_final_outputs)
|
||||
return output_messages
|
||||
|
||||
async def _execute_tool_call(
|
||||
self,
|
||||
function: OpenAIChatCompletionToolCallFunction,
|
||||
) -> Optional[ToolInvocationResult]:
|
||||
if not function.name:
|
||||
return None
|
||||
function_args = json.loads(function.arguments) if function.arguments else {}
|
||||
logger.info(f"executing tool call: {function.name} with args: {function_args}")
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=function.name,
|
||||
kwargs=function_args,
|
||||
)
|
||||
logger.debug(f"tool call {function.name} completed with result: {result}")
|
||||
return result
|
|
@ -17,10 +17,8 @@ from llama_stack.apis.common.type_system import (
|
|||
DialogType,
|
||||
StringType,
|
||||
)
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
|
||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||
|
@ -36,21 +34,3 @@ EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
|||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
async def validate_input_dataset_schema(
|
||||
datasets_api: Datasets,
|
||||
dataset_id: str,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def:
|
||||
raise ValueError(f"Dataset {dataset_id} does not exist.")
|
||||
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
if dataset_type not in EXPECTED_DATASET_SCHEMA:
|
||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||
|
||||
validate_dataset_schema(dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type])
|
||||
|
|
|
@ -48,9 +48,6 @@ from llama_stack.apis.post_training import (
|
|||
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.post_training.common.validator import (
|
||||
validate_input_dataset_schema,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.common import utils
|
||||
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
|
||||
TorchtuneCheckpointer,
|
||||
|
@ -348,11 +345,9 @@ class LoraFinetuningSingleDevice:
|
|||
all_rows = await fetch_rows(dataset_id)
|
||||
rows = all_rows.data
|
||||
|
||||
await validate_input_dataset_schema(
|
||||
datasets_api=self.datasets_api,
|
||||
dataset_id=dataset_id,
|
||||
dataset_type=self._data_format.value,
|
||||
)
|
||||
# TODO (xiyan): validate dataset schema
|
||||
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
|
||||
data_transform = await utils.get_data_transform(self._data_format)
|
||||
ds = SFTDataset(
|
||||
rows,
|
||||
|
|
|
@ -30,7 +30,7 @@ class TelemetryConfig(BaseModel):
|
|||
)
|
||||
service_name: str = Field(
|
||||
# service name is always the same, use zero-width space to avoid clutter
|
||||
default="",
|
||||
default="",
|
||||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: List[TelemetrySink] = Field(
|
||||
|
@ -52,7 +52,7 @@ class TelemetryConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:}",
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
|
||||
"sqlite_db_path": "${env.SQLITE_STORE_DIR:" + __distro_dir__ + "}/" + db_name,
|
||||
}
|
||||
|
|
|
@ -227,6 +227,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import LlamaCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: LlamaCompatConfig, _deps) -> Inference:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .llama import LlamaCompatInferenceAdapter
|
||||
|
||||
adapter = LlamaCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class LlamaProviderDataValidator(BaseModel):
|
||||
llama_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for api.llama models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class LlamaCompatConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Llama API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.llama.com/compat/v1/",
|
||||
description="The URL for the Llama API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.llama.com/compat/v1/",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
|
||||
LlamaCompatConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: LlamaCompatConfig
|
||||
|
||||
def __init__(self, config: LlamaCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-3.3-70B-Instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
|
@ -433,6 +433,12 @@ class OllamaInferenceAdapter(
|
|||
user: Optional[str] = None,
|
||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self._get_model(model)
|
||||
|
||||
# ollama still makes tool calls even when tool_choice is "none"
|
||||
# so we need to remove the tools in that case
|
||||
if tool_choice == "none" and tools is not None:
|
||||
tools = None
|
||||
|
||||
params = {
|
||||
k: v
|
||||
for k, v in {
|
||||
|
|
|
@ -90,6 +90,9 @@ class LiteLLMOpenAIMixin(
|
|||
raise ValueError(f"Unsupported model: {model.provider_resource_id}")
|
||||
return model
|
||||
|
||||
def get_litellm_model_name(self, model_id: str) -> str:
|
||||
return "openai/" + model_id if self.is_openai_compat else model_id
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -130,8 +133,7 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
|
||||
params = await self._get_params(request)
|
||||
if self.is_openai_compat:
|
||||
params["model"] = "openai/" + params["model"]
|
||||
params["model"] = self.get_litellm_model_name(params["model"])
|
||||
|
||||
logger.debug(f"params to litellm (openai compat): {params}")
|
||||
# unfortunately, we need to use synchronous litellm.completion here because litellm
|
||||
|
@ -220,21 +222,23 @@ class LiteLLMOpenAIMixin(
|
|||
else request.tool_config.tool_choice
|
||||
)
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"api_key": self.get_api_key(),
|
||||
"api_base": self.api_base,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
provider_data = self.get_request_provider_data()
|
||||
key_field = self.provider_data_api_key_field
|
||||
if provider_data and getattr(provider_data, key_field, None):
|
||||
api_key = getattr(provider_data, key_field)
|
||||
else:
|
||||
api_key = self.api_key_from_config
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
"api_key": api_key,
|
||||
"api_base": self.api_base,
|
||||
**input_dict,
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
return api_key
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
@ -247,7 +251,7 @@ class LiteLLMOpenAIMixin(
|
|||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
response = litellm.embedding(
|
||||
model=model.provider_resource_id,
|
||||
model=self.get_litellm_model_name(model.provider_resource_id),
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
||||
|
@ -278,7 +282,7 @@ class LiteLLMOpenAIMixin(
|
|||
) -> OpenAICompletion:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
|
@ -297,6 +301,8 @@ class LiteLLMOpenAIMixin(
|
|||
user=user,
|
||||
guided_choice=guided_choice,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.atext_completion(**params)
|
||||
|
||||
|
@ -328,7 +334,7 @@ class LiteLLMOpenAIMixin(
|
|||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||
model_obj = await self.model_store.get_model(model)
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id,
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
@ -351,6 +357,8 @@ class LiteLLMOpenAIMixin(
|
|||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.acompletion(**params)
|
||||
|
||||
|
|
|
@ -638,10 +638,13 @@ async def convert_message_to_openai_dict_new(
|
|||
)
|
||||
for tool in message.tool_calls
|
||||
]
|
||||
params = {}
|
||||
if tool_calls:
|
||||
params["tool_calls"] = tool_calls
|
||||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=await _convert_message_content(message.content),
|
||||
tool_calls=tool_calls or None,
|
||||
**params,
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
|
|
|
@ -478,6 +478,8 @@ class JsonSchemaGenerator:
|
|||
}
|
||||
return ret
|
||||
elif origin_type is Literal:
|
||||
if len(typing.get_args(typ)) != 1:
|
||||
raise ValueError(f"Literal type {typ} has {len(typing.get_args(typ))} arguments")
|
||||
(literal_value,) = typing.get_args(typ) # unpack value of literal type
|
||||
schema = self.type_to_schema(type(literal_value))
|
||||
schema["const"] = literal_value
|
||||
|
|
|
@ -39,9 +39,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/bedrock}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -79,9 +79,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/cerebras}/trace_store.db
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -42,9 +42,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ci-tests}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,9 +41,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dell}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -344,6 +344,45 @@
|
|||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"llama_api": [
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"emoji",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"langdetect",
|
||||
"litellm",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"pythainlp",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"sqlite-vec",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
"sentence-transformers --no-deps",
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||
],
|
||||
"meta-reference-gpu": [
|
||||
"accelerate",
|
||||
"aiosqlite",
|
||||
|
|
|
@ -71,9 +71,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/dev}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,9 +50,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/fireworks}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/groq}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,9 +50,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-endpoint}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,9 +50,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/hf-serverless}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
7
llama_stack/templates/llama_api/__init__.py
Normal file
7
llama_stack/templates/llama_api/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .llama_api import get_distribution_template # noqa: F401
|
33
llama_stack/templates/llama_api/build.yaml
Normal file
33
llama_stack/templates/llama_api/build.yaml
Normal file
|
@ -0,0 +1,33 @@
|
|||
version: '2'
|
||||
distribution_spec:
|
||||
description: Distribution for running e2e tests in CI
|
||||
providers:
|
||||
inference:
|
||||
- remote::llama-openai-compat
|
||||
- inline::sentence-transformers
|
||||
vector_io:
|
||||
- inline::sqlite-vec
|
||||
- remote::chromadb
|
||||
- remote::pgvector
|
||||
safety:
|
||||
- inline::llama-guard
|
||||
agents:
|
||||
- inline::meta-reference
|
||||
telemetry:
|
||||
- inline::meta-reference
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
datasetio:
|
||||
- remote::huggingface
|
||||
- inline::localfs
|
||||
scoring:
|
||||
- inline::basic
|
||||
- inline::llm-as-judge
|
||||
- inline::braintrust
|
||||
tool_runtime:
|
||||
- remote::brave-search
|
||||
- remote::tavily-search
|
||||
- inline::code-interpreter
|
||||
- inline::rag-runtime
|
||||
- remote::model-context-protocol
|
||||
image_type: conda
|
159
llama_stack/templates/llama_api/llama_api.py
Normal file
159
llama_stack/templates/llama_api/llama_api.py
Normal file
|
@ -0,0 +1,159 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelInput,
|
||||
Provider,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||
SQLiteVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
|
||||
LlamaCompatConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.models import (
|
||||
MODEL_ENTRIES as LLLAMA_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.templates.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
get_model_registry,
|
||||
)
|
||||
|
||||
|
||||
def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
|
||||
# in this template, we allow each API key to be optional
|
||||
providers = [
|
||||
(
|
||||
"llama-openai-compat",
|
||||
LLLAMA_MODEL_ENTRIES,
|
||||
LlamaCompatConfig.sample_run_config(api_key="${env.LLAMA_API_KEY:}"),
|
||||
),
|
||||
]
|
||||
inference_providers = []
|
||||
available_models = {}
|
||||
for provider_id, model_entries, config in providers:
|
||||
inference_providers.append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
provider_type=f"remote::{provider_id}",
|
||||
config=config,
|
||||
)
|
||||
)
|
||||
available_models[provider_id] = model_entries
|
||||
return inference_providers, available_models
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
inference_providers, available_models = get_inference_providers()
|
||||
providers = {
|
||||
"inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]),
|
||||
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"],
|
||||
"safety": ["inline::llama-guard"],
|
||||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||
"tool_runtime": [
|
||||
"remote::brave-search",
|
||||
"remote::tavily-search",
|
||||
"inline::code-interpreter",
|
||||
"inline::rag-runtime",
|
||||
"remote::model-context-protocol",
|
||||
],
|
||||
}
|
||||
name = "llama_api"
|
||||
|
||||
vector_io_providers = [
|
||||
Provider(
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(url="${env.CHROMADB_URL:}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_PGVECTOR+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
db="${env.PGVECTOR_DB:}",
|
||||
user="${env.PGVECTOR_USER:}",
|
||||
password="${env.PGVECTOR_PASSWORD:}",
|
||||
),
|
||||
),
|
||||
]
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
provider_id="tavily-search",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
provider_id="rag-runtime",
|
||||
),
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::code_interpreter",
|
||||
provider_id="code-interpreter",
|
||||
),
|
||||
]
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id=embedding_provider.provider_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
default_models = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Distribution for running e2e tests in CI",
|
||||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": inference_providers + [embedding_provider],
|
||||
"vector_io": vector_io_providers,
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMA_STACK_PORT": (
|
||||
"8321",
|
||||
"Port for the Llama Stack distribution server",
|
||||
),
|
||||
},
|
||||
)
|
167
llama_stack/templates/llama_api/run.yaml
Normal file
167
llama_stack/templates/llama_api/run.yaml
Normal file
|
@ -0,0 +1,167 @@
|
|||
version: '2'
|
||||
image_name: llama_api
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: llama-openai-compat
|
||||
provider_type: remote::llama-openai-compat
|
||||
config:
|
||||
openai_compat_api_base: https://api.llama.com/compat/v1/
|
||||
api_key: ${env.LLAMA_API_KEY:}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
config: {}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/sqlite_vec.db
|
||||
- provider_id: ${env.ENABLE_CHROMADB+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:}
|
||||
- provider_id: ${env.ENABLE_PGVECTOR+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
host: ${env.PGVECTOR_HOST:localhost}
|
||||
port: ${env.PGVECTOR_PORT:5432}
|
||||
db: ${env.PGVECTOR_DB:}
|
||||
user: ${env.PGVECTOR_USER:}
|
||||
password: ${env.PGVECTOR_PASSWORD:}
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/llama_api}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: Llama-3.3-70B-Instruct
|
||||
provider_id: llama-openai-compat
|
||||
provider_model_id: Llama-3.3-70B-Instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: llama-openai-compat
|
||||
provider_model_id: Llama-3.3-70B-Instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: Llama-4-Scout-17B-16E-Instruct-FP8
|
||||
provider_id: llama-openai-compat
|
||||
provider_model_id: Llama-4-Scout-17B-16E-Instruct-FP8
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct
|
||||
provider_id: llama-openai-compat
|
||||
provider_model_id: Llama-4-Scout-17B-16E-Instruct-FP8
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
|
||||
provider_id: llama-openai-compat
|
||||
provider_model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct
|
||||
provider_id: llama-openai-compat
|
||||
provider_model_id: Llama-4-Maverick-17B-128E-Instruct-FP8
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: sentence-transformers
|
||||
model_type: embedding
|
||||
shields:
|
||||
- shield_id: meta-llama/Llama-Guard-3-8B
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::code_interpreter
|
||||
provider_id: code-interpreter
|
||||
server:
|
||||
port: 8321
|
|
@ -60,9 +60,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,9 +50,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,9 +50,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/trace_store.db
|
||||
eval:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/trace_store.db
|
||||
eval:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
|
|
|
@ -43,9 +43,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -41,9 +41,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -68,9 +68,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/open-benchmark}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,9 +50,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/passthrough}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -88,9 +88,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/trace_store.db
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -81,9 +81,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/trace_store.db
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -51,9 +51,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/sambanova/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/sambanova}/trace_store.db
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -44,9 +44,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/tgi}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -50,9 +50,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -45,9 +45,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/together}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -78,9 +78,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/verification/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/verification}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -49,9 +49,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/vllm-gpu/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/vllm-gpu}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -43,9 +43,9 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
service_name: ${env.OTEL_SERVICE_NAME:}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue