Merge branch 'main' into agents-openai-migration

This commit is contained in:
Matthew Farrellee 2025-09-30 17:51:45 -04:00
commit 724322eeb2
673 changed files with 164269 additions and 14378 deletions

View file

@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from .openai_responses import (
@ -481,7 +482,10 @@ class Agents(Protocol):
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
"""
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
@webmethod(
route="/agents", method="POST", descriptive_name="create_agent", deprecated=True, level=LLAMA_STACK_API_V1
)
@webmethod(route="/agents", method="POST", descriptive_name="create_agent", level=LLAMA_STACK_API_V1ALPHA)
async def create_agent(
self,
agent_config: AgentConfig,
@ -494,7 +498,17 @@ class Agents(Protocol):
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_turn(
self,
@ -524,6 +538,14 @@ class Agents(Protocol):
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def resume_agent_turn(
self,
@ -549,6 +571,13 @@ class Agents(Protocol):
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_turn(
self,
@ -568,6 +597,13 @@ class Agents(Protocol):
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_step(
self,
@ -586,7 +622,19 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_session(
self,
agent_id: str,
@ -600,7 +648,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agents_session(
self,
session_id: str,
@ -616,7 +665,10 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
@webmethod(
route="/agents/{agent_id}/session/{session_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1
)
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agents_session(
self,
session_id: str,
@ -629,7 +681,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE")
@webmethod(route="/agents/{agent_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
agent_id: str,
@ -640,7 +693,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents", method="GET")
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
@ -650,7 +704,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}", method="GET")
@webmethod(route="/agents/{agent_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
@ -659,7 +714,8 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
@webmethod(route="/agents/{agent_id}/sessions", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
agent_id: str,
@ -682,7 +738,7 @@ class Agents(Protocol):
#
# Both of these APIs are inherently stateful.
@webmethod(route="/openai/v1/responses/{response_id}", method="GET")
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_openai_response(
self,
response_id: str,
@ -694,7 +750,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="POST")
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
@ -719,7 +775,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="GET")
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_responses(
self,
after: str | None = None,
@ -737,7 +793,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_response_input_items(
self,
response_id: str,
@ -759,7 +815,7 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE")
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete an OpenAI response by its ID.

View file

@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
tools: list[MCPListToolsTool]
@json_schema_type
class OpenAIResponseMCPApprovalRequest(BaseModel):
"""
A request for human approval of a tool invocation.
"""
arguments: str
id: str
name: str
server_label: str
type: Literal["mcp_approval_request"] = "mcp_approval_request"
@json_schema_type
class OpenAIResponseMCPApprovalResponse(BaseModel):
"""
A response to an MCP approval request.
"""
approval_request_id: str
approve: bool
type: Literal["mcp_approval_response"] = "mcp_approval_response"
id: str | None = None
reason: str | None = None
OpenAIResponseOutput = Annotated[
OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
| OpenAIResponseOutputMessageMCPListTools
| OpenAIResponseMCPApprovalRequest,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@ -336,7 +363,6 @@ class OpenAIResponseObject(BaseModel):
:param text: Text formatting configuration for the response
:param top_p: (Optional) Nucleus sampling parameter used for generation
:param truncation: (Optional) Truncation strategy applied to the response
:param user: (Optional) User identifier associated with the request
"""
created_at: int
@ -354,7 +380,6 @@ class OpenAIResponseObject(BaseModel):
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
truncation: str | None = None
user: str | None = None
@json_schema_type
@ -725,6 +750,8 @@ OpenAIResponseInput = Annotated[
| OpenAIResponseOutputMessageFileSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseInputFunctionToolCallOutput
| OpenAIResponseMCPApprovalRequest
| OpenAIResponseMCPApprovalResponse
|
# Fallback to the generic message type as a last resort
OpenAIResponseMessage,

View file

@ -1,78 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import (
InterleavedContent,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import webmethod
@runtime_checkable
class BatchInference(Protocol):
"""Batch inference API for generating completions and chat completions.
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
including (post-training, evals, etc).
"""
@webmethod(route="/batch-inference/completion", method="POST")
async def completion(
self,
model: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job:
"""Generate completions for a batch of content.
:param model: The model to use for the completion.
:param content_batch: The content to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param response_format: The response format to use for the completion.
:param logprobs: The logprobs to use for the completion.
:returns: A job for the completion.
"""
...
@webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion(
self,
model: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
# zero-shot tool definitions as input to the model
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> Job:
"""Generate chat completions for a batch of messages.
:param model: The model to use for the chat completion.
:param messages_batch: The messages to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param tools: The tools to use for the chat completion.
:param tool_choice: The tool choice to use for the chat completion.
:param tool_prompt_format: The tool prompt format to use for the chat completion.
:param response_format: The response format to use for the chat completion.
:param logprobs: The logprobs to use for the chat completion.
:returns: A job for the chat completion.
"""
...

View file

@ -8,6 +8,7 @@ from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
try:
@ -42,7 +43,7 @@ class Batches(Protocol):
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/openai/v1/batches", method="POST")
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
async def create_batch(
self,
input_file_id: str,
@ -62,7 +63,7 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
@ -71,7 +72,7 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
@ -80,7 +81,7 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches", method="GET")
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
async def list_batches(
self,
after: str | None = None,

View file

@ -8,6 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, webmethod
@ -53,7 +54,8 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET")
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
@ -61,7 +63,8 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_benchmark(
self,
benchmark_id: str,
@ -73,7 +76,8 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks", method="POST")
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def register_benchmark(
self,
benchmark_id: str,
@ -93,3 +97,12 @@ class Benchmarks(Protocol):
:param metadata: The metadata to use for the benchmark.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.
:param benchmark_id: The ID of the benchmark to unregister.
"""
...

View file

@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import webmethod
@ -20,7 +21,7 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def iterrows(
self,
dataset_id: str,
@ -44,7 +45,7 @@ class DatasetIO(Protocol):
"""
...
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.

View file

@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -145,7 +146,7 @@ class ListDatasetsResponse(BaseModel):
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST")
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1)
async def register_dataset(
self,
purpose: DatasetPurpose,
@ -214,7 +215,7 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_dataset(
self,
dataset_id: str,
@ -226,7 +227,7 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets", method="GET")
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1)
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
@ -234,7 +235,7 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_dataset(
self,
dataset_id: str,

View file

@ -13,6 +13,7 @@ from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -83,7 +84,8 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def run_eval(
self,
benchmark_id: str,
@ -97,7 +99,10 @@ class Eval(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
@webmethod(
route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def evaluate_rows(
self,
benchmark_id: str,
@ -115,7 +120,10 @@ class Eval(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
@ -125,7 +133,13 @@ class Eval(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
@ -134,7 +148,15 @@ class Eval(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
)
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.

View file

@ -11,6 +11,7 @@ from fastapi import File, Form, Response, UploadFile
from pydantic import BaseModel, Field
from llama_stack.apis.common.responses import Order
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -104,14 +105,12 @@ class OpenAIFileDeleteResponse(BaseModel):
@trace_protocol
class Files(Protocol):
# OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST")
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
async def openai_upload_file(
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
# TODO: expires_after is producing strange openapi spec, params are showing up as a required w/ oneOf being null
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
) -> OpenAIFileObject:
"""
Upload a file that can be used across various endpoints.
@ -119,15 +118,16 @@ class Files(Protocol):
The file upload should be a multipart form request with:
- file: The File object (not file name) to be uploaded.
- purpose: The intended purpose of the uploaded file.
- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).
- expires_after: Optional form values describing expiration for the file.
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
:param expires_after: Optional form values describing expiration for the file.
:returns: An OpenAIFileObject representing the uploaded file.
"""
...
@webmethod(route="/openai/v1/files", method="GET")
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_files(
self,
after: str | None = None,
@ -146,7 +146,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="GET")
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file(
self,
file_id: str,
@ -159,7 +159,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE")
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_file(
self,
file_id: str,
@ -172,7 +172,7 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET")
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file_content(
self,
file_id: str,

View file

@ -17,10 +17,11 @@ from typing import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry import MetricResponseMixin
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.models.llama.datatypes import (
BuiltinTool,
StopReason,
@ -913,6 +914,7 @@ class OpenAIEmbeddingData(BaseModel):
"""
object: Literal["embedding"] = "embedding"
# TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData
embedding: list[float] | str
index: int
@ -973,26 +975,6 @@ class EmbeddingTaskType(Enum):
document = "document"
@json_schema_type
class BatchCompletionResponse(BaseModel):
"""Response from a batch completion request.
:param batch: List of completion responses, one for each input in the batch
"""
batch: list[CompletionResponse]
@json_schema_type
class BatchChatCompletionResponse(BaseModel):
"""Response from a batch chat completion request.
:param batch: List of chat completion responses, one for each conversation in the batch
"""
batch: list[ChatCompletionResponse]
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam]
@ -1026,7 +1008,6 @@ class InferenceProvider(Protocol):
model_store: ModelStore | None = None
@webmethod(route="/inference/completion", method="POST")
async def completion(
self,
model_id: str,
@ -1049,28 +1030,6 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/inference/batch-completion", method="POST", experimental=True)
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
"""Generate completions for a batch of content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content_batch: The content to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/chat-completion", method="POST")
async def chat_completion(
self,
model_id: str,
@ -1110,52 +1069,7 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
"""Generate chat completions for a batch of messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages_batch: The messages to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_config: (Optional) Configuration for tool use.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
"""
...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def rerank(
self,
model: str,
@ -1174,7 +1088,7 @@ class InferenceProvider(Protocol):
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST")
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
# Standard OpenAI completion parameters
@ -1225,7 +1139,7 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/chat/completions", method="POST")
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
model: str,
@ -1281,7 +1195,7 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/embeddings", method="POST")
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
model: str,
@ -1310,7 +1224,7 @@ class Inference(InferenceProvider):
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET")
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
async def list_chat_completions(
self,
after: str | None = None,
@ -1328,7 +1242,7 @@ class Inference(InferenceProvider):
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.

View file

@ -8,6 +8,7 @@ from typing import Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod
@ -57,7 +58,7 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable
class Inspect(Protocol):
@webmethod(route="/inspect/routes", method="GET")
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self) -> ListRoutesResponse:
"""List all available API routes with their methods and implementing providers.
@ -65,7 +66,7 @@ class Inspect(Protocol):
"""
...
@webmethod(route="/health", method="GET")
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
async def health(self) -> HealthInfo:
"""Get the current health status of the service.
@ -73,7 +74,7 @@ class Inspect(Protocol):
"""
...
@webmethod(route="/version", method="GET")
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
async def version(self) -> VersionInfo:
"""Get the version of the service.

View file

@ -10,6 +10,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field, field_validator
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -102,7 +103,7 @@ class OpenAIListModelsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET")
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
async def list_models(self) -> ListModelsResponse:
"""List all models.
@ -110,15 +111,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/openai/v1/models", method="GET")
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: A OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET")
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_model(
self,
model_id: str,
@ -130,7 +123,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/models", method="POST")
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
async def register_model(
self,
model_id: str,
@ -150,7 +143,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE")
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_model(
self,
model_id: str,

View file

@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -283,7 +284,8 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def supervised_fine_tune(
self,
job_uuid: str,
@ -310,7 +312,8 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST")
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def preference_optimize(
self,
job_uuid: str,
@ -332,7 +335,8 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/jobs", method="GET")
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
@ -340,7 +344,8 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/status", method="GET")
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
@ -349,7 +354,8 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/cancel", method="POST")
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
@ -357,7 +363,8 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET")
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.

View file

@ -10,6 +10,7 @@ from typing import Protocol, runtime_checkable
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -95,7 +96,7 @@ class ListPromptsResponse(BaseModel):
class Prompts(Protocol):
"""Protocol for prompt management operations."""
@webmethod(route="/prompts", method="GET")
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompts(self) -> ListPromptsResponse:
"""List all prompts.
@ -103,7 +104,7 @@ class Prompts(Protocol):
"""
...
@webmethod(route="/prompts/{prompt_id}/versions", method="GET")
@webmethod(route="/prompts/{prompt_id}/versions", method="GET", level=LLAMA_STACK_API_V1)
async def list_prompt_versions(
self,
prompt_id: str,
@ -115,7 +116,7 @@ class Prompts(Protocol):
"""
...
@webmethod(route="/prompts/{prompt_id}", method="GET")
@webmethod(route="/prompts/{prompt_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_prompt(
self,
prompt_id: str,
@ -129,7 +130,7 @@ class Prompts(Protocol):
"""
...
@webmethod(route="/prompts", method="POST")
@webmethod(route="/prompts", method="POST", level=LLAMA_STACK_API_V1)
async def create_prompt(
self,
prompt: str,
@ -143,7 +144,7 @@ class Prompts(Protocol):
"""
...
@webmethod(route="/prompts/{prompt_id}", method="PUT")
@webmethod(route="/prompts/{prompt_id}", method="PUT", level=LLAMA_STACK_API_V1)
async def update_prompt(
self,
prompt_id: str,
@ -163,7 +164,7 @@ class Prompts(Protocol):
"""
...
@webmethod(route="/prompts/{prompt_id}", method="DELETE")
@webmethod(route="/prompts/{prompt_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_prompt(
self,
prompt_id: str,
@ -174,7 +175,7 @@ class Prompts(Protocol):
"""
...
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT")
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT", level=LLAMA_STACK_API_V1)
async def set_default_version(
self,
prompt_id: str,

View file

@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod
@ -45,7 +46,7 @@ class Providers(Protocol):
Providers API for inspecting, listing, and modifying providers and their configurations.
"""
@webmethod(route="/providers", method="GET")
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
async def list_providers(self) -> ListProvidersResponse:
"""List all available providers.
@ -53,7 +54,7 @@ class Providers(Protocol):
"""
...
@webmethod(route="/providers/{provider_id}", method="GET")
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider.

View file

@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -97,7 +98,7 @@ class ShieldStore(Protocol):
class Safety(Protocol):
shield_store: ShieldStore
@webmethod(route="/safety/run-shield", method="POST")
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
async def run_shield(
self,
shield_id: str,
@ -113,7 +114,7 @@ class Safety(Protocol):
"""
...
@webmethod(route="/openai/v1/moderations", method="POST")
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
"""Classifies if text and/or image inputs are potentially harmful.
:param input: Input (or inputs) to classify.

View file

@ -9,6 +9,7 @@ from typing import Any, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
@ -61,7 +62,7 @@ class ScoringFunctionStore(Protocol):
class Scoring(Protocol):
scoring_function_store: ScoringFunctionStore
@webmethod(route="/scoring/score-batch", method="POST")
@webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
async def score_batch(
self,
dataset_id: str,
@ -77,7 +78,7 @@ class Scoring(Protocol):
"""
...
@webmethod(route="/scoring/score", method="POST")
@webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
async def score(
self,
input_rows: list[dict[str, Any]],

View file

@ -18,6 +18,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -160,7 +161,7 @@ class ListScoringFunctionsResponse(BaseModel):
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
@webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""List all scoring functions.
@ -168,7 +169,7 @@ class ScoringFunctions(Protocol):
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
"""Get a scoring function by its ID.
@ -177,7 +178,7 @@ class ScoringFunctions(Protocol):
"""
...
@webmethod(route="/scoring-functions", method="POST")
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
async def register_scoring_function(
self,
scoring_fn_id: str,
@ -197,3 +198,11 @@ class ScoringFunctions(Protocol):
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
"""Unregister a scoring function.
:param scoring_fn_id: The ID of the scoring function to unregister.
"""
...

View file

@ -9,6 +9,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -49,7 +50,7 @@ class ListShieldsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Shields(Protocol):
@webmethod(route="/shields", method="GET")
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
async def list_shields(self) -> ListShieldsResponse:
"""List all shields.
@ -57,7 +58,7 @@ class Shields(Protocol):
"""
...
@webmethod(route="/shields/{identifier:path}", method="GET")
@webmethod(route="/shields/{identifier:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_shield(self, identifier: str) -> Shield:
"""Get a shield by its identifier.
@ -66,7 +67,7 @@ class Shields(Protocol):
"""
...
@webmethod(route="/shields", method="POST")
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
async def register_shield(
self,
shield_id: str,
@ -84,7 +85,7 @@ class Shields(Protocol):
"""
...
@webmethod(route="/shields/{identifier:path}", method="DELETE")
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_shield(self, identifier: str) -> None:
"""Unregister a shield.

View file

@ -10,6 +10,7 @@ from typing import Any, Protocol
from pydantic import BaseModel
from llama_stack.apis.inference import Message
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
@ -59,7 +60,7 @@ class SyntheticDataGenerationResponse(BaseModel):
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic-data-generation/generate")
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
def synthetic_data_generate(
self,
dialogs: list[Message],

View file

@ -16,6 +16,7 @@ from typing import (
from pydantic import BaseModel, Field
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -412,7 +413,7 @@ class QueryMetricsResponse(BaseModel):
@runtime_checkable
class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST")
@webmethod(route="/telemetry/events", method="POST", level=LLAMA_STACK_API_V1)
async def log_event(
self,
event: Event,
@ -425,7 +426,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1)
async def query_traces(
self,
attribute_filters: list[QueryCondition] | None = None,
@ -443,7 +444,9 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
@webmethod(
route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1
)
async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
@ -453,7 +456,10 @@ class Telemetry(Protocol):
...
@webmethod(
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
method="GET",
required_scope=REQUIRED_SCOPE,
level=LLAMA_STACK_API_V1,
)
async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
@ -464,7 +470,12 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
@webmethod(
route="/telemetry/spans/{span_id:path}/tree",
method="POST",
required_scope=REQUIRED_SCOPE,
level=LLAMA_STACK_API_V1,
)
async def get_span_tree(
self,
span_id: str,
@ -480,7 +491,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1)
async def query_spans(
self,
attribute_filters: list[QueryCondition],
@ -496,7 +507,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans/export", method="POST")
@webmethod(route="/telemetry/spans/export", method="POST", level=LLAMA_STACK_API_V1)
async def save_spans_to_dataset(
self,
attribute_filters: list[QueryCondition],
@ -513,7 +524,9 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
@webmethod(
route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1
)
async def query_metrics(
self,
metric_name: str,

View file

@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, field_validator
from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -185,7 +186,7 @@ class RAGQueryConfig(BaseModel):
@runtime_checkable
@trace_protocol
class RAGToolRuntime(Protocol):
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert(
self,
documents: list[RAGDocument],
@ -200,7 +201,7 @@ class RAGToolRuntime(Protocol):
"""
...
@webmethod(route="/tool-runtime/rag-tool/query", method="POST")
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
async def query(
self,
content: InterleavedContent,

View file

@ -12,6 +12,7 @@ from typing_extensions import runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -26,6 +27,8 @@ class ToolParameter(BaseModel):
:param parameter_type: Type of the parameter (e.g., string, integer)
:param description: Human-readable description of what the parameter does
:param required: Whether this parameter is required for tool invocation
:param items: Type of the elements when parameter_type is array
:param title: (Optional) Title of the parameter
:param default: (Optional) Default value for the parameter if not provided
"""
@ -33,6 +36,8 @@ class ToolParameter(BaseModel):
parameter_type: str
description: str
required: bool = Field(default=True)
items: dict | None = None
title: str | None = None
default: Any | None = None
@ -151,7 +156,7 @@ class ListToolDefsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@webmethod(route="/toolgroups", method="POST")
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
async def register_tool_group(
self,
toolgroup_id: str,
@ -168,7 +173,7 @@ class ToolGroups(Protocol):
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_tool_group(
self,
toolgroup_id: str,
@ -180,7 +185,7 @@ class ToolGroups(Protocol):
"""
...
@webmethod(route="/toolgroups", method="GET")
@webmethod(route="/toolgroups", method="GET", level=LLAMA_STACK_API_V1)
async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider.
@ -188,7 +193,7 @@ class ToolGroups(Protocol):
"""
...
@webmethod(route="/tools", method="GET")
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group.
@ -197,7 +202,7 @@ class ToolGroups(Protocol):
"""
...
@webmethod(route="/tools/{tool_name:path}", method="GET")
@webmethod(route="/tools/{tool_name:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_tool(
self,
tool_name: str,
@ -209,7 +214,7 @@ class ToolGroups(Protocol):
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_toolgroup(
self,
toolgroup_id: str,
@ -238,7 +243,7 @@ class ToolRuntime(Protocol):
rag_tool: RAGToolRuntime | None = None
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
@webmethod(route="/tool-runtime/list-tools", method="GET")
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
@ -250,7 +255,7 @@ class ToolRuntime(Protocol):
"""
...
@webmethod(route="/tool-runtime/invoke", method="POST")
@webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1)
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments.

View file

@ -9,6 +9,7 @@ from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@ -65,7 +66,7 @@ class ListVectorDBsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET")
@webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
@ -73,7 +74,7 @@ class VectorDBs(Protocol):
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
async def get_vector_db(
self,
vector_db_id: str,
@ -85,7 +86,7 @@ class VectorDBs(Protocol):
"""
...
@webmethod(route="/vector-dbs", method="POST")
@webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
async def register_vector_db(
self,
vector_db_id: str,
@ -107,7 +108,7 @@ class VectorDBs(Protocol):
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.

View file

@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack.schema_utils import json_schema_type, webmethod
@ -317,7 +318,8 @@ class VectorStoreChunkingStrategyStatic(BaseModel):
VectorStoreChunkingStrategy = Annotated[
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic, Field(discriminator="type")
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic,
Field(discriminator="type"),
]
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
@ -426,6 +428,44 @@ class VectorStoreFileDeleteResponse(BaseModel):
deleted: bool = True
@json_schema_type
class VectorStoreFileBatchObject(BaseModel):
"""OpenAI Vector Store File Batch object.
:param id: Unique identifier for the file batch
:param object: Object type identifier, always "vector_store.file_batch"
:param created_at: Timestamp when the file batch was created
:param vector_store_id: ID of the vector store containing the file batch
:param status: Current processing status of the file batch
:param file_counts: File processing status counts for the batch
"""
id: str
object: str = "vector_store.file_batch"
created_at: int
vector_store_id: str
status: VectorStoreFileStatus
file_counts: VectorStoreFileCounts
@json_schema_type
class VectorStoreFilesListInBatchResponse(BaseModel):
"""Response from listing files in a vector store file batch.
:param object: Object type identifier, always "list"
:param data: List of vector store file objects in the batch
:param first_id: (Optional) ID of the first file in the list for pagination
:param last_id: (Optional) ID of the last file in the list for pagination
:param has_more: Whether there are more files available beyond this page
"""
object: str = "list"
data: list[VectorStoreFileObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
@ -437,7 +477,7 @@ class VectorIO(Protocol):
# this will just block now until chunks are inserted, but it should
# probably return a Job instance which can be polled for completion
@webmethod(route="/vector-io/insert", method="POST")
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
async def insert_chunks(
self,
vector_db_id: str,
@ -455,7 +495,7 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/vector-io/query", method="POST")
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
async def query_chunks(
self,
vector_db_id: str,
@ -472,7 +512,7 @@ class VectorIO(Protocol):
...
# OpenAI Vector Stores API endpoints
@webmethod(route="/openai/v1/vector_stores", method="POST")
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store(
self,
name: str | None = None,
@ -498,7 +538,7 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores", method="GET")
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
@ -516,7 +556,7 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="GET")
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
@ -528,7 +568,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="POST")
@webmethod(
route="/vector_stores/{vector_store_id}",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_update_vector_store(
self,
vector_store_id: str,
@ -546,7 +590,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE")
@webmethod(
route="/vector_stores/{vector_store_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
@ -558,7 +606,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/search", method="POST")
@webmethod(
route="/vector_stores/{vector_store_id}/search",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
@ -567,7 +619,9 @@ class VectorIO(Protocol):
max_num_results: int | None = 10,
ranking_options: SearchRankingOptions | None = None,
rewrite_query: bool | None = False,
search_mode: str | None = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
search_mode: (
str | None
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store.
@ -584,7 +638,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="POST")
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_attach_file_to_vector_store(
self,
vector_store_id: str,
@ -602,7 +660,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="GET")
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_list_files_in_vector_store(
self,
vector_store_id: str,
@ -624,7 +686,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="GET")
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_retrieve_vector_store_file(
self,
vector_store_id: str,
@ -638,7 +704,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content", method="GET")
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_retrieve_vector_store_file_contents(
self,
vector_store_id: str,
@ -652,7 +722,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="POST")
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_update_vector_store_file(
self,
vector_store_id: str,
@ -668,7 +742,11 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="DELETE")
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
)
async def openai_delete_vector_store_file(
self,
vector_store_id: str,
@ -681,3 +759,89 @@ class VectorIO(Protocol):
:returns: A VectorStoreFileDeleteResponse indicating the deletion status.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
"""Create a vector store file batch.
:param vector_store_id: The ID of the vector store to create the file batch for.
:param file_ids: A list of File IDs that the vector store should use.
:param attributes: (Optional) Key-value attributes to store with the files.
:param chunking_strategy: (Optional) The chunking strategy used to chunk the file(s). Defaults to auto.
:returns: A VectorStoreFileBatchObject representing the created file batch.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Retrieve a vector store file batch.
:param batch_id: The ID of the file batch to retrieve.
:param vector_store_id: The ID of the vector store containing the file batch.
:returns: A VectorStoreFileBatchObject representing the file batch.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
"""Returns a list of vector store files in a batch.
:param batch_id: The ID of the file batch to list files from.
:param vector_store_id: The ID of the vector store containing the file batch.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
:param filter: Filter by file status. One of in_progress, completed, failed, cancelled.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:returns: A VectorStoreFilesListInBatchResponse containing the list of files in the batch.
"""
...
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",
level=LLAMA_STACK_API_V1,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Cancels a vector store file batch.
:param batch_id: The ID of the file batch to cancel.
:param vector_store_id: The ID of the vector store containing the file batch.
:returns: A VectorStoreFileBatchObject representing the cancelled file batch.
"""
...

View file

@ -4,4 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_API_VERSION = "v1"
LLAMA_STACK_API_V1 = "v1"
LLAMA_STACK_API_V1BETA = "v1beta"
LLAMA_STACK_API_V1ALPHA = "v1alpha"

View file

@ -48,15 +48,12 @@ def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
# NOTE: MD5 is used here only for download integrity verification,
# not for security purposes
# TODO: switch to SHA256
md5_hash = hashlib.md5(usedforsecurity=False)
def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
sha256_hash = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest()
sha256_hash.update(chunk)
return sha256_hash.hexdigest()
def load_checksums(checklist_path: Path) -> dict[str, str]:
@ -64,10 +61,10 @@ def load_checksums(checklist_path: Path) -> dict[str, str]:
with open(checklist_path) as f:
for line in f:
if line.strip():
md5sum, filepath = line.strip().split(" ", 1)
sha256sum, filepath = line.strip().split(" ", 1)
# Remove leading './' if present
filepath = filepath.lstrip("./")
checksums[filepath] = md5sum
checksums[filepath] = sha256sum
return checksums
@ -88,7 +85,7 @@ def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -
matches = False
if exists:
actual_hash = calculate_md5(full_path)
actual_hash = calculate_sha256(full_path)
matches = actual_hash == expected_hash
results.append(

View file

@ -147,7 +147,7 @@ WORKDIR /app
RUN dnf -y update && dnf install -y iputils git net-tools wget \
vim-minimal python3.12 python3.12-pip python3.12-wheel \
python3.12-setuptools python3.12-devel gcc make && \
python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
procps psmisc lsof \
traceroute \
bubblewrap \
gcc \
gcc g++ \
&& rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1

View file

@ -15,7 +15,6 @@ import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
@ -114,7 +113,24 @@ def create_api_client_class(protocol) -> type:
break
kwargs[param.name] = args[i]
url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])
if not webmethods:
raise RuntimeError(f"Method {method} has no webmethod decorators")
# Choose the preferred webmethod (non-deprecated if available)
preferred_webmethod = None
for wm in webmethods:
if not getattr(wm, "deprecated", False):
preferred_webmethod = wm
break
# If no non-deprecated found, use the first one
if preferred_webmethod is None:
preferred_webmethod = webmethods[0]
url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"
def convert(value):
if isinstance(value, list):

View file

@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
default=None,
)
@property
def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec):
@ -437,6 +433,12 @@ class InferenceStoreConfig(BaseModel):
num_writers: int = Field(default=4, description="Number of concurrent background writers")
class ResponsesStoreConfig(BaseModel):
sql_store_config: SqlStoreConfig
max_write_queue_size: int = Field(default=10000, description="Max queued writes for responses store")
num_writers: int = Field(default=4, description="Number of concurrent background writers")
class StackRunConfig(BaseModel):
version: int = LLAMA_STACK_RUN_CONFIG_VERSION

View file

@ -16,16 +16,18 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
logger = get_logger(name=__name__, category="core")
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
def stack_apis() -> list[Api]:
return list(Api)
@ -70,31 +72,16 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
def providable_apis() -> list[Api]:
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
return [api for api in Api if api not in routing_table_apis and api not in INTERNAL_APIS]
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
adapter=adapter,
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
return spec
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
pip_packages=spec_data.get("pip_packages", []),
module=spec_data["module"],
config_class=spec_data["config_class"],
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
provider_data_validator=spec_data.get("provider_data_validator"),
container_image=spec_data.get("container_image"),
)
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
return spec

View file

@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
construct_stack,
Stack,
get_stack_run_config_from_distro,
replace_env_vars,
)
@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
stack = Stack(self.config, self.custom_provider_registry)
await stack.initialize()
self.impls = stack.impls
except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr)
cprint(
@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
raise _e
assert self.impls is not None
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])

View file

@ -29,6 +29,7 @@ from llama_stack.apis.telemetry import Telemetry
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDBs
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
AccessRule,
@ -412,8 +413,14 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
mro = type(obj).__mro__
for name, value in inspect.getmembers(protocol):
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
if value.__webmethod__.experimental:
if inspect.isfunction(value) and hasattr(value, "__webmethods__"):
has_alpha_api = False
for webmethod in value.__webmethods__:
if webmethod.level == LLAMA_STACK_API_V1ALPHA:
has_alpha_api = True
break
# if this API has multiple webmethods, and one of them is an alpha API, this API should be skipped when checking for missing or not callable routes
if has_alpha_api:
continue
if not hasattr(obj, name):
missing_methods.append((name, "missing"))

View file

@ -16,20 +16,15 @@ from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionResponse,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
@ -50,7 +45,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -273,30 +267,6 @@ class InferenceRouter(Inference):
)
return response
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
tools: list[ToolDefinition] | None = None,
tool_config: ToolConfig | None = None,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse:
logger.debug(
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_chat_completion(
model_id=model_id,
messages_batch=messages_batch,
tools=tools,
tool_config=tool_config,
sampling_params=sampling_params,
response_format=response_format,
logprobs=logprobs,
)
async def completion(
self,
model_id: str,
@ -338,39 +308,6 @@ class InferenceRouter(Inference):
return response
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
logger.debug(
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
await self._get_model(model_id, ModelType.embedding)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
async def openai_completion(
self,
model: str,

View file

@ -8,9 +8,7 @@ import asyncio
import uuid
from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
@ -19,9 +17,11 @@ from llama_stack.apis.vector_io import (
VectorIO,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
VectorStoreFileStatus,
VectorStoreListResponse,
VectorStoreObject,
@ -193,7 +193,10 @@ class VectorIORouter(VectorIO):
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
before_index = next(
(i for i, store in enumerate(all_stores) if store.id == before),
len(all_stores),
)
all_stores = all_stores[:before_index]
# Apply limit
@ -363,3 +366,61 @@ class VectorIORouter(VectorIO):
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
)
return health_statuses
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
return await self.routing_table.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
return await self.routing_table.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)

View file

@ -56,3 +56,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
async def unregister_benchmark(self, benchmark_id: str) -> None:
existing_benchmark = await self.get_benchmark(benchmark_id)
await self.unregister_object(existing_benchmark)

View file

@ -64,6 +64,10 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
return await p.unregister_shield(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.eval:
return await p.unregister_benchmark(obj.identifier)
elif api == Api.scoring:
return await p.unregister_scoring_function(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_toolgroup(obj.identifier)
else:

View file

@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)

View file

@ -60,3 +60,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
await self.unregister_object(existing_scoring_fn)

View file

@ -9,7 +9,7 @@ from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.core.datatypes import ToolGroupWithOwner
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
@ -54,7 +54,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup)
try:
await self._index_tools(toolgroup)
except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers.
raise
except Exception as e:
# Other errors that the client cannot fix are logged and
# those specific toolgroups are skipped.
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
logger.debug(e, exc_info=True)
continue
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)

View file

@ -14,7 +14,6 @@ from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
@ -54,22 +53,23 @@ def get_all_api_routes(
protocol_methods.append((f"{tool_group.value}.{name}", method))
for name, method in protocol_methods:
if not hasattr(method, "__webmethod__"):
# Get all webmethods for this method (supports multiple decorators)
webmethods = getattr(method, "__webmethods__", [])
if not webmethods:
continue
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
webmethod = method.__webmethod__ # type: ignore[attr-defined]
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
if webmethod.method == hdrs.METH_GET:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else:
http_method = hdrs.METH_POST
routes.append(
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since don't use a Router object
# Create routes for each webmethod decorator
for webmethod in webmethods:
path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}"
if webmethod.method == hdrs.METH_GET:
http_method = hdrs.METH_GET
elif webmethod.method == hdrs.METH_DELETE:
http_method = hdrs.METH_DELETE
else:
http_method = hdrs.METH_POST
routes.append(
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since don't use a Router object
apis[api] = routes

View file

@ -6,6 +6,7 @@
import argparse
import asyncio
import concurrent.futures
import functools
import inspect
import json
@ -24,7 +25,6 @@ from typing import Annotated, Any, get_origin
import httpx
import rich.pretty
import yaml
from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
@ -44,23 +44,17 @@ from llama_stack.core.datatypes import (
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
from llama_stack.core.external import load_external_apis
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.server.routes import (
find_matching_route,
get_all_api_routes,
initialize_route_impls,
)
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.core.stack import (
Stack,
cast_image_name_to_string,
construct_stack,
replace_env_vars,
shutdown_stack,
validate_env_pair,
)
from llama_stack.core.utils.config import redact_sensitive_fields
@ -74,13 +68,12 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
)
from llama_stack.providers.utils.telemetry.tracing import (
CURRENT_TRACE_CONTEXT,
end_trace,
setup_logger,
start_trace,
)
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
from .tracing import TracingMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -156,21 +149,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
)
async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
class StackApp(FastAPI):
"""
await shutdown_stack(app.__llama_stack_impls__)
A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
"""
def __init__(self, config: StackRunConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stack: Stack = Stack(config)
# This code is called from a running event loop managed by uvicorn so we cannot simply call
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
# function.
# As a workaround, we use a thread pool executor to run the initialize() method
# in a separate thread.
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self.stack.initialize())
future.result()
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: StackApp):
logger.info("Starting up")
assert app.stack is not None
app.stack.create_registry_refresh_task()
yield
logger.info("Shutting down")
await shutdown(app)
await app.stack.shutdown()
def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -287,65 +293,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
return route_handler
class TracingMiddleware:
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()
class ClientVersionMiddleware:
def __init__(self, app):
self.app = app
@ -386,73 +333,61 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
def create_app(
config_file: str | None = None,
env_vars: list[str] | None = None,
) -> StackApp:
"""Create and configure the FastAPI application.
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
Args:
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
env_vars: List of environment variables in KEY=value format.
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
Returns:
Configured StackApp instance.
"""
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
if config_file is None:
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
config_or_distro = get_config_from_args(args)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
config_file = resolve_config_or_distro(config_file, Mode.RUN)
# Load and process configuration
logger_config = None
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env:
if env_vars:
for env_pair in env_vars:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
_log_run_config(run_config=config)
app = FastAPI(
app = StackApp(
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
config=config,
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
impls = app.stack.impls
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
@ -553,9 +488,54 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
return app
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
config_or_distro = get_config_from_args(args)
try:
app = create_app(
config_file=config_or_distro,
env_vars=args.env,
)
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
sys.exit(1)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
import uvicorn
# Configure SSL if certificates are provided
@ -593,7 +573,6 @@ def main(args: argparse.Namespace | None = None):
if ssl_config:
uvicorn_config.update(ssl_config)
# Run uvicorn in the existing event loop to preserve background tasks
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@ -604,13 +583,9 @@ def main(args: argparse.Namespace | None = None):
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig):

View file

@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from aiohttp import hdrs
from llama_stack.core.external import ExternalApiSpec
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
logger = get_logger(name=__name__, category="core::server")
class TracingMiddleware:
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
async def __call__(self, scope, receive, send):
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
# Check if the path is a FastAPI built-in path
if path.startswith(self.fastapi_paths):
# Pass through to FastAPI's built-in handlers
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
return await self.app(scope, receive, send)
# Log deprecation warning if route is deprecated
if getattr(webmethod, "deprecated", False):
logger.warning(
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
f"This route is deprecated and may be removed in a future version. "
f"Please check the docs for the supported version."
)
trace_attributes = {"__location__": "server", "raw_path": path}
# Extract W3C trace context headers and store as trace attributes
headers = dict(scope.get("headers", []))
traceparent = headers.get(b"traceparent", b"").decode()
if traceparent:
trace_attributes["traceparent"] = traceparent
tracestate = headers.get(b"tracestate", b"").decode()
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()

View file

@ -14,7 +14,6 @@ from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@ -54,7 +53,6 @@ class LlamaStack(
Providers,
VectorDBs,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
@ -315,78 +313,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
impls[Api.prompts] = prompts_impl
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
self.provider_registry = provider_registry
self.impls = None
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
impls = await resolve_impls(
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, self.run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
self.impls = impls
def create_registry_refresh_task(self):
assert self.impls is not None, "Must call initialize() before starting"
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
async def shutdown(self):
for impl in self.impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(run_config, impls)
await refresh_registry_once(impls)
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry_once(impls: dict[Api, Any]):

View file

@ -123,6 +123,6 @@ if [[ "$env_type" == "venv" ]]; then
$other_args
elif [[ "$env_type" == "container" ]]; then
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
echo -e "Please refer to the documentation for more information: https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html#llama-stack-build"
echo -e "Please refer to the documentation for more information: https://llamastack.github.io/latest/distributions/building_distro.html#llama-stack-build"
exit 1
fi

View file

@ -6,7 +6,7 @@
## Developer Setup
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
1. Start up Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.htmll).
```
llama stack build --distro together --image-type venv

View file

@ -17,6 +17,7 @@ distribution_spec:
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: remote::azure
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss

View file

@ -81,6 +81,13 @@ providers:
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:

View file

@ -23,6 +23,8 @@ distribution_spec:
- provider_type: inline::basic
tool_runtime:
- provider_type: inline::rag-runtime
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite

View file

@ -49,22 +49,22 @@ The deployed platform includes the NIM Proxy microservice, which is the service
### Datasetio API: NeMo Data Store
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposts APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
See the [NVIDIA Datasetio docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.
### Eval API: NeMo Evaluator
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
See the [NVIDIA Eval docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.
### Post-Training API: NeMo Customizer
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
See the [NVIDIA Post-Training docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.
### Safety API: NeMo Guardrails
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
See the [NVIDIA Safety docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/safety/nvidia/README.md) for supported features and example usage.
## Deploying models
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
@ -138,4 +138,4 @@ llama stack run ./run.yaml \
```
## Example Notebooks
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in [docs/notebooks/nvidia](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks/nvidia).

View file

@ -7,15 +7,15 @@
from pathlib import Path
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
def get_distribution_template() -> DistributionTemplate:
def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
providers = {
"inference": [BuildProvider(provider_type="remote::nvidia")],
"vector_io": [BuildProvider(provider_type="inline::faiss")],
@ -30,6 +30,7 @@ def get_distribution_template() -> DistributionTemplate:
],
"scoring": [BuildProvider(provider_type="inline::basic")],
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
@ -52,6 +53,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia",
config=NVIDIAEvalConfig.sample_run_config(),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
@ -61,9 +67,6 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="nvidia",
)
available_models = {
"nvidia": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::rag",
@ -71,23 +74,21 @@ def get_distribution_template() -> DistributionTemplate:
),
]
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="nvidia",
name=name,
distro_type="self_hosted",
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"datasetio": [datasetio_provider],
"eval": [eval_provider],
"files": [files_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
@ -97,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
safety_provider,
],
"eval": [eval_provider],
"files": [files_provider],
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
@ -88,6 +89,14 @@ providers:
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
@ -77,96 +78,21 @@ providers:
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
inference_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
models:
- metadata: {}
model_id: meta/llama3-8b-instruct
provider_id: nvidia
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama3-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-8b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.1-405b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-1b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-3b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-11b-vision-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.2-90b-vision-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta/llama-3.3-70b-instruct
provider_id: nvidia
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: {}
model_id: nvidia/vila
provider_id: nvidia
provider_model_id: nvidia/vila
model_type: llm
- metadata:
embedding_dimension: 2048
context_length: 8192
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: nvidia
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: nvidia/nv-embedqa-e5-v5
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-e5-v5
model_type: embedding
- metadata:
embedding_dimension: 4096
context_length: 512
model_id: nvidia/nv-embedqa-mistral-7b-v2
provider_id: nvidia
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: snowflake/arctic-embed-l
provider_id: nvidia
provider_model_id: snowflake/arctic-embed-l
model_type: embedding
models: []
shields: []
vector_dbs: []
datasets: []

View file

@ -18,6 +18,7 @@ distribution_spec:
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: remote::azure
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss

View file

@ -81,6 +81,13 @@ providers:
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:

View file

@ -18,6 +18,7 @@ distribution_spec:
- provider_type: remote::vertexai
- provider_type: remote::groq
- provider_type: remote::sambanova
- provider_type: remote::azure
- provider_type: inline::sentence-transformers
vector_io:
- provider_type: inline::faiss

View file

@ -81,6 +81,13 @@ providers:
config:
url: https://api.sambanova.ai/v1
api_key: ${env.SAMBANOVA_API_KEY:=}
- provider_id: ${env.AZURE_API_KEY:+azure}
provider_type: remote::azure
config:
api_key: ${env.AZURE_API_KEY:=}
api_base: ${env.AZURE_API_BASE:=}
api_version: ${env.AZURE_API_VERSION:=}
api_type: ${env.AZURE_API_TYPE:=}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io:

View file

@ -59,6 +59,7 @@ ENABLED_INFERENCE_PROVIDERS = [
"cerebras",
"nvidia",
"bedrock",
"azure",
]
INFERENCE_PROVIDER_IDS = {
@ -68,6 +69,7 @@ INFERENCE_PROVIDER_IDS = {
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
"vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}",
"azure": "${env.AZURE_API_KEY:+azure}",
}
@ -76,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
remote_providers = [
provider
for provider in available_providers()
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
]
inference_providers = []
for provider_spec in remote_providers:
provider_type = provider_spec.adapter.adapter_type
provider_type = provider_spec.adapter_type
if provider_type in INFERENCE_PROVIDER_IDS:
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
@ -277,5 +279,21 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
"http://localhost:11434",
"Ollama URL",
),
"AZURE_API_KEY": (
"",
"Azure API Key",
),
"AZURE_API_BASE": (
"",
"Azure API Base",
),
"AZURE_API_VERSION": (
"",
"Azure API Version",
),
"AZURE_API_TYPE": (
"azure",
"Azure API Type",
),
},
)

View file

@ -10,6 +10,7 @@ apis:
- telemetry
- tool_runtime
- vector_io
- files
providers:
inference:
- provider_id: watsonx
@ -94,6 +95,14 @@ providers:
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/watsonx/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db

View file

@ -9,6 +9,7 @@ from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
@ -16,7 +17,7 @@ from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
def get_distribution_template() -> DistributionTemplate:
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
providers = {
"inference": [
BuildProvider(provider_type="remote::watsonx"),
@ -42,6 +43,7 @@ def get_distribution_template() -> DistributionTemplate:
BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"),
],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
@ -79,9 +81,14 @@ def get_distribution_template() -> DistributionTemplate:
},
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="watsonx",
name=name,
distro_type="remote_hosted",
description="Use watsonx for running LLM inference",
container_image=None,
@ -92,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider, embedding_provider],
"files": [files_provider],
},
default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups,

View file

@ -92,6 +92,8 @@ class ToolParamDefinition(BaseModel):
param_type: str
description: str | None = None
required: bool | None = True
items: Any | None = None
title: str | None = None
default: Any | None = None

View file

@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
provider_data_validator: str | None = Field(
default=None,
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now
@ -145,45 +154,8 @@ class RoutingTable(Protocol):
async def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: str = Field(
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
A description of the provider. This is used to display in the documentation.
""",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
container_image: str | None = Field(
default=None,
description="""
@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: AdapterSpec = Field(
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
description: str | None = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here.
A description of the provider. This is used to display in the documentation.
""",
)
@ -234,33 +207,6 @@ API responses, specify the adapter here.
def container_image(self) -> str | None:
return None
# module field is inherited from ProviderSpec
@property
def pip_packages(self) -> list[str]:
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> str | None:
return self.adapter.provider_data_validator
def remote_provider_spec(
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)
class HealthStatus(StrEnum):
OK = "OK"

View file

@ -830,6 +830,8 @@ class ChatAgent(ShieldRunnerMixin):
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
@ -873,6 +875,8 @@ class ChatAgent(ShieldRunnerMixin):
param_type=param.parameter_type,
description=param.description,
required=param.required,
items=param.items,
title=param.title,
default=param.default,
)
for param in tool_def.parameters
@ -952,7 +956,7 @@ async def get_raw_document_text(document: Document) -> str:
DeprecationWarning,
stacklevel=2,
)
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
elif not (document.mime_type.startswith("text/") or document.mime_type in ("application/yaml", "application/json")):
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
if isinstance(document.content, URL):

View file

@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
response_tools=tools,
temperature=temperature,
response_format=response_format,
inputs=input,
)
# Create orchestrator and delegate streaming logic

View file

@ -10,10 +10,12 @@ from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
ApprovalFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
@ -50,6 +52,36 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal
logger = get_logger(name=__name__, category="agents::meta_reference")
def convert_tooldef_to_chat_tool(tool_def):
"""Convert a ToolDef to OpenAI ChatCompletionToolParam format.
Args:
tool_def: ToolDef from the tools API
Returns:
ChatCompletionToolParam suitable for OpenAI chat completion
"""
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
internal_tool_def = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
items=param.items,
)
for param in tool_def.parameters
},
)
return convert_tooldef_to_openai_tool(internal_tool_def)
class StreamingResponseOrchestrator:
def __init__(
self,
@ -117,10 +149,17 @@ class StreamingResponseOrchestrator:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# add any approval requests required
for tool_call in approvals:
async for evt in self._add_mcp_approval_request(
tool_call.function.name, tool_call.function.arguments, output_messages
):
yield evt
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
@ -164,10 +203,11 @@ class StreamingResponseOrchestrator:
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
approvals = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
@ -178,9 +218,23 @@ class StreamingResponseOrchestrator:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
if self._approval_required(tool_call.function.name):
approval_response = self.ctx.approval_response(
tool_call.function.name, tool_call.function.arguments
)
if approval_response:
if approval_response.approve:
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
non_function_tool_calls.append(tool_call)
else:
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
else:
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
approvals.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
@ -556,23 +610,7 @@ class StreamingResponseOrchestrator:
continue
if not always_allowed or t.name in always_allowed:
# Add to chat tools for inference
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in t.parameters
},
)
openai_tool = convert_tooldef_to_openai_tool(tool_def)
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
@ -632,3 +670,46 @@ class StreamingResponseOrchestrator:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
return False
mcp_server = self.mcp_tool_to_server[tool_name]
if mcp_server.require_approval == "always":
return True
if mcp_server.require_approval == "never":
return False
if isinstance(mcp_server, ApprovalFilter):
if tool_name in mcp_server.always:
return True
if tool_name in mcp_server.never:
return False
return True
async def _add_mcp_approval_request(
self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
mcp_server = self.mcp_tool_to_server[tool_name]
mcp_approval_request = OpenAIResponseMCPApprovalRequest(
arguments=arguments,
id=f"approval_{uuid.uuid4()}",
name=tool_name,
server_label=mcp_server.server_label,
)
output_messages.append(mcp_approval_request)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_approval_request,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)

View file

@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
)
@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
def __init__(
self,
model: str,
messages: list[OpenAIMessageParam],
response_tools: list[OpenAIResponseInputTool] | None,
temperature: float | None,
response_format: OpenAIResponseFormatParam,
inputs: list[OpenAIResponseInput] | str,
):
super().__init__(
model=model,
messages=messages,
response_tools=response_tools,
temperature=temperature,
response_format=response_format,
)
if not isinstance(inputs, str):
self.approval_requests = [input for input in inputs if input.type == "mcp_approval_request"]
self.approval_responses = {
input.approval_request_id: input for input in inputs if input.type == "mcp_approval_response"
}
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
request = self._approval_request(tool_name, arguments)
return self.approval_responses.get(request.id, None) if request else None
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
for request in self.approval_requests:
if request.name == tool_name and request.arguments == arguments:
return request
return None

View file

@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMCPApprovalResponse,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
# the tool list will be handled separately
pass
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
input_item, OpenAIResponseMCPApprovalResponse
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)

View file

@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
)
self.benchmarks[task_def.identifier] = task_def
async def unregister_benchmark(self, benchmark_id: str) -> None:
if benchmark_id in self.benchmarks:
del self.benchmarks[benchmark_id]
key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
await self.kvstore.delete(key)
async def run_eval(
self,
benchmark_id: str,
@ -152,31 +159,40 @@ class MetaReferenceEvalImpl(
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
if candidate.sampling_params.stop:
sampling_params["stop"] = candidate.sampling_params.stop
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion(
response = await self.inference_api.openai_completion(
model=candidate.model,
content=input_content,
sampling_params=candidate.sampling_params,
prompt=input_content,
**sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
input_messages = [
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
messages += input_messages
response = await self.inference_api.chat_completion(
model_id=candidate.model,
response = await self.inference_api.openai_chat_completion(
model=candidate.model,
messages=messages,
sampling_params=candidate.sampling_params,
**sampling_params,
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
else:
raise ValueError("Invalid input row")

View file

@ -9,11 +9,12 @@ import uuid
from pathlib import Path
from typing import Annotated
from fastapi import File, Form, Response, UploadFile
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import (
ExpiresAfter,
Files,
ListOpenAIFileResponse,
OpenAIFileDeleteResponse,
@ -22,6 +23,7 @@ from llama_stack.apis.files import (
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -44,7 +46,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table(
"openai_files",
{
@ -74,7 +76,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -86,14 +88,13 @@ class LocalfsFilesImpl(Files):
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
"""Upload a file that can be used across various endpoints."""
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
if expires_after_anchor is not None or expires_after_seconds is not None:
if expires_after is not None:
raise NotImplementedError("File expiration is not supported by this provider")
file_id = self._generate_file_id()
@ -150,7 +151,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import (
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_completion([request])
return results[0]
async def batch_completion(
self,
model_id: str,
content_batch: list[InterleavedContent],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
content_batch = [
augment_content_with_response_format_prompt(response_format, content) for content in content_batch
]
request_batch = []
for content in content_batch:
request = CompletionRequest(
model=model_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
self.check_model(request)
request = await convert_request_to_raw(request)
request_batch.append(request)
results = await self._nonstream_completion(request_batch)
return BatchCompletionResponse(batch=results)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
tokenizer = self.generator.formatter.tokenizer
@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl(
results = await self._nonstream_chat_completion([request])
return results[0]
async def batch_chat_completion(
self,
model_id: str,
messages_batch: list[list[Message]],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
if sampling_params is None:
sampling_params = SamplingParams()
if logprobs:
assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request_batch = []
for messages in messages_batch:
request = ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
logprobs=logprobs,
tool_config=tool_config or ToolConfig(),
)
self.check_model(request)
# augment and rewrite messages depending on the model
request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
# download media and convert to raw content so we can send it to the model
request = await convert_request_to_raw(request)
request_batch.append(request)
if self.config.create_distributed_process_group:
if SEMAPHORE.locked():
raise RuntimeError("Only one concurrent request is supported")
results = await self._nonstream_chat_completion(request_batch)
return BatchChatCompletionResponse(batch=results)
async def _nonstream_chat_completion(
self, request_batch: list[ChatCompletionRequest]
) -> list[ChatCompletionResponse]:

View file

@ -290,13 +290,13 @@ class LlamaGuardShield:
else:
shield_input_message = self.build_text_shield_input(messages)
# TODO: llama-stack inference protocol has issues with non-streaming inference code
response = await self.inference_api.chat_completion(
model_id=self.model,
response = await self.inference_api.openai_chat_completion(
model=self.model,
messages=[shield_input_message],
stream=False,
temperature=0.0, # default is 1, which is too high for safety
)
content = response.completion_message.content
content = response.choices[0].message.content
content = content.strip()
return self.get_shield_response(content)

View file

@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
async def register_scoring_function(self, function_def: ScoringFn) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
async def score_batch(
self,
dataset_id: str,

View file

@ -224,10 +224,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
return _GLOBAL_STORAGE["gauges"][name]
def _log_metric(self, event: MetricEvent) -> None:
# Always log to console if console sink is enabled (debug)
if TelemetrySink.CONSOLE in self.config.sinks:
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
# Add metric as an event to the current span
try:
with self._lock:

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
content = template.render({"messages": messages})
rendered_content: str = template.render({"messages": messages})
model = config.model
message = UserMessage(content=content)
response = await inference_api.chat_completion(
model_id=model,
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
model=model,
messages=[message],
stream=False,
)
query = response.completion_message.content
query = response.choices[0].message.content
return query

View file

@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
parse_data_url,
)
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
else:
content_str = interleaved_content_as_str(doc.content)
if content_str.startswith("data:"):
parts = parse_data_url(content_str)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return
for doc in documents:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
mime_type = parts["mimetype"]
else:
async with httpx.AsyncClient() as client:
response = await client.get(doc.content.uri)
file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
try:
try:
file_data, mime_type = await raw_data_from_doc(doc)
except Exception as e:
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
continue
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
upload_file = UploadFile(file=file_obj, filename=filename)
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
try:
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
)
)
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
try:
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query(
self,
@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
query = kwargs["query"]
@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
return ToolInvocationResult(
content=result.content,
content=result.content or [],
metadata=result.metadata,
)

View file

@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.batches,
provider_type="inline::reference",
pip_packages=["openai"],
pip_packages=[],
module="llama_stack.providers.inline.batches.reference",
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
api_dependencies=[

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[],
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="huggingface",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
adapter_type="huggingface",
provider_type="remote::huggingface",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
pip_packages=[
"datasets>=4.0.0",
],
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def available_providers() -> list[ProviderSpec]:
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
],
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.eval,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"requests",
],
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
),
adapter_type="nvidia",
pip_packages=[
"requests",
],
provider_type="remote::nvidia",
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
api_dependencies=[
Api.datasetio,
Api.datasets,

View file

@ -4,13 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.files,
adapter=AdapterSpec(
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
provider_type="remote::s3",
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
META_REFERENCE_DEPS = [
@ -49,180 +48,167 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
description="Sentence Transformers inference provider for text embeddings and similarity search.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ollama",
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama",
description="Ollama inference provider for running local models through the Ollama runtime.",
),
adapter_type="ollama",
provider_type="remote::ollama",
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama",
description="Ollama inference provider for running local models through the Ollama runtime.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vllm",
pip_packages=["openai"],
module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
description="Remote vLLM inference provider for connecting to vLLM servers.",
),
adapter_type="vllm",
provider_type="remote::vllm",
pip_packages=[],
module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
description="Remote vLLM inference provider for connecting to vLLM servers.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="tgi",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
),
adapter_type="tgi",
provider_type="remote::tgi",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
description="HuggingFace Inference API serverless provider for on-demand model inference.",
),
adapter_type="hf::serverless",
provider_type="remote::hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
description="HuggingFace Inference API serverless provider for on-demand model inference.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
),
provider_type="remote::hf::endpoint",
adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="fireworks",
pip_packages=[
"fireworks-ai<=0.17.16",
],
module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
),
adapter_type="fireworks",
provider_type="remote::fireworks",
pip_packages=[
"fireworks-ai<=0.17.16",
],
module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="together",
pip_packages=[
"together",
],
module="llama_stack.providers.remote.inference.together",
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
description="Together AI inference provider for open-source models and collaborative AI development.",
),
adapter_type="together",
provider_type="remote::together",
pip_packages=[
"together",
],
module="llama_stack.providers.remote.inference.together",
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
description="Together AI inference provider for open-source models and collaborative AI development.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
),
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="databricks",
pip_packages=[
"openai",
],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
adapter_type="databricks",
provider_type="remote::databricks",
pip_packages=["databricks-sdk"],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"openai",
],
module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=[],
module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="runpod",
pip_packages=["openai"],
module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
adapter_type="runpod",
provider_type="remote::runpod",
pip_packages=[],
module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="openai",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
),
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="anthropic",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
),
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="gemini",
pip_packages=["litellm", "openai"],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
),
adapter_type="gemini",
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vertexai",
pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
adapter_type="vertexai",
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
Enterprise-grade security: Uses Google Cloud's security controls and IAM
Better integration: Seamless integration with other Google Cloud services
@ -242,61 +228,73 @@ Available Models:
- vertex_ai/gemini-2.0-flash
- vertex_ai/gemini-2.5-flash
- vertex_ai/gemini-2.5-pro""",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["litellm", "openai"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
adapter_type="groq",
provider_type="remote::groq",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="llama-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm", "openai"],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
),
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="passthrough",
pip_packages=[],
module="llama_stack.providers.remote.inference.passthrough",
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
),
adapter_type="passthrough",
provider_type="remote::passthrough",
pip_packages=[],
module="llama_stack.providers.remote.inference.passthrough",
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="watsonx",
pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
),
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
),
RemoteProviderSpec(
api=Api.inference,
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
description="""
Azure OpenAI inference provider for accessing GPT models and other Azure services.
Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""",
),
]

View file

@ -7,7 +7,7 @@
from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.post_training,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
),
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.",
),
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.",
),
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.",
),
]

View file

@ -38,7 +38,7 @@ def available_providers() -> list[ProviderSpec]:
InlineProviderSpec(
api=Api.scoring,
provider_type="inline::braintrust",
pip_packages=["autoevals", "openai"],
pip_packages=["autoevals"],
module="llama_stack.providers.inline.scoring.braintrust",
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
api_dependencies=[

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
description="Brave Search tool for web search capabilities with privacy-focused results.",
),
adapter_type="brave-search",
provider_type="remote::brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
description="Brave Search tool for web search capabilities with privacy-focused results.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
),
adapter_type="bing-search",
provider_type="remote::bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
description="Tavily Search tool for AI-optimized web search with structured results.",
),
adapter_type="tavily-search",
provider_type="remote::tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
description="Tavily Search tool for AI-optimized web search with structured results.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
),
adapter_type="wolfram-alpha",
provider_type="remote::wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp>=1.8.1"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
),
adapter_type="model-context-protocol",
provider_type="remote::model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp>=1.8.1"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
Please refer to the sqlite-vec provider documentation.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="chromadb",
pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
description="""
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="chromadb",
provider_type="remote::chromadb",
pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
That means you're not limited to storing vectors in memory or in a separate service.
@ -340,9 +341,6 @@ pip install chromadb
## Documentation
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="pgvector",
pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
description="""
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="pgvector",
provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval.
@ -410,7 +410,7 @@ There are three implementations of search for PGVectoIndex available:
- How it works:
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
- Eg. SQL query: SELECT document, embedding &lt;=&gt; %s::vector AS distance FROM table ORDER BY distance
-Characteristics:
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17
## Documentation
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
""",
),
),
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="weaviate",
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
description="""
description="""
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.
That means you're not limited to storing vectors in memory or in a separate service.
@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
## Documentation
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
@ -594,28 +590,29 @@ docker pull qdrant/qdrant
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="qdrant",
pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
description="""
Please refer to the inline provider documentation.
""",
),
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="qdrant",
provider_type="remote::qdrant",
pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
Please refer to the inline provider documentation.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="milvus",
pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
description="""
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="milvus",
provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service.
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
## Installation
You can install Milvus using pymilvus:
If you want to use inline Milvus, you can install:
```bash
pip install pymilvus[milvus-lite]
```
If you want to use remote Milvus, you can install:
```bash
pip install pymilvus
@ -806,14 +809,11 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus>=2.4.10"],
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],

View file

@ -14,7 +14,6 @@ from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .....apis.common.job_types import Job, JobStatus
@ -45,24 +44,29 @@ class NVIDIAEvalImpl(
self.inference_api = inference_api
self.agents_api = agents_api
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
ModelRegistryHelper.__init__(self)
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path):
async def _evaluator_get(self, path: str):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path, data):
async def _evaluator_post(self, path: str, data: dict[str, Any]):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def _evaluator_delete(self, path: str) -> None:
"""Helper for making DELETE requests to the evaluator service."""
response = requests.delete(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
@ -75,6 +79,10 @@ class NVIDIAEvalImpl(
},
)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark evaluation configuration from NeMo Evaluator."""
await self._evaluator_delete(f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}")
async def run_eval(
self,
benchmark_id: str,

View file

@ -10,7 +10,7 @@ from typing import Annotated, Any
import boto3
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
from fastapi import File, Form, Response, UploadFile
from fastapi import Depends, File, Form, Response, UploadFile
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.common.responses import Order
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
OpenAIFilePurpose,
)
from llama_stack.core.datatypes import AccessRule
from llama_stack.providers.utils.files.form_data import parse_expires_after
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
@ -137,7 +138,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id}
if not return_expired:
where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()")
return row
@ -164,7 +165,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table(
"openai_files",
{
@ -195,8 +196,7 @@ class S3FilesImpl(Files):
self,
file: Annotated[UploadFile, File()],
purpose: Annotated[OpenAIFilePurpose, Form()],
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
) -> OpenAIFileObject:
file_id = f"file-{uuid.uuid4().hex}"
@ -204,14 +204,6 @@ class S3FilesImpl(Files):
created_at = self._now()
expires_after = None
if expires_after_anchor is not None or expires_after_seconds is not None:
# we use ExpiresAfter to validate input
expires_after = ExpiresAfter(
anchor=expires_after_anchor, # type: ignore[arg-type]
seconds=expires_after_seconds, # type: ignore[arg-type]
)
# the default is no expiration.
# to implement no expiration we set an expiration beyond the max.
# we'll hide this fact from users when returning the file object.
@ -268,7 +260,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import AnthropicConfig
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: str | None = None
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter

View file

@ -8,14 +8,24 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AnthropicConfig
from .models import MODEL_ENTRIES
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
# TODO: add support for voyageai, which is where these models are hosted
# embedding_model_metadata = {
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
# }
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
litellm_provider_name="anthropic",
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",

View file

@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
)
LLM_MODEL_IDS = [
"claude-3-5-sonnet-latest",
"claude-3-7-sonnet-latest",
"claude-3-5-haiku-latest",
]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+ [
ProviderModelEntry(
provider_model_id="voyage-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="voyage-3-lite",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 512, "context_length": 32000},
),
ProviderModelEntry(
provider_model_id="voyage-code-3",
model_type=ModelType.embedding,
metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import AzureConfig
async def get_adapter_impl(config: AzureConfig, _deps):
from .azure import AzureInferenceAdapter
impl = AzureInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from urllib.parse import urljoin
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import AzureConfig
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: AzureConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
litellm_provider_name="azure",
api_key_from_config=config.api_key.get_secret_value(),
provider_data_api_key_field="azure_api_key",
openai_compat_api_base=str(config.api_base),
)
self.config = config
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the Azure API base URL.
Returns the Azure API base URL from the configuration.
"""
return urljoin(str(self.config.api_base), "/openai/v1")
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)
# Add Azure specific parameters
provider_data = self.get_request_provider_data()
if provider_data:
if getattr(provider_data, "azure_api_key", None):
params["api_key"] = provider_data.azure_api_key
if getattr(provider_data, "azure_api_base", None):
params["api_base"] = provider_data.azure_api_base
if getattr(provider_data, "azure_api_version", None):
params["api_version"] = provider_data.azure_api_version
if getattr(provider_data, "azure_api_type", None):
params["api_type"] = provider_data.azure_api_type
else:
params["api_key"] = self.config.api_key.get_secret_value()
params["api_base"] = str(self.config.api_base)
params["api_version"] = self.config.api_version
params["api_type"] = self.config.api_type
return params

View file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any
from pydantic import BaseModel, Field, HttpUrl, SecretStr
from llama_stack.schema_utils import json_schema_type
class AzureProviderDataValidator(BaseModel):
azure_api_key: SecretStr = Field(
description="Azure API key for Azure",
)
azure_api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)
azure_api_version: str | None = Field(
default=None,
description="Azure API version for Azure (e.g., 2024-06-01)",
)
azure_api_type: str | None = Field(
default="azure",
description="Azure API type for Azure (e.g., azure)",
)
@json_schema_type
class AzureConfig(BaseModel):
api_key: SecretStr = Field(
description="Azure API key for Azure",
)
api_base: HttpUrl = Field(
description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
)
api_version: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_VERSION"),
description="Azure API version for Azure (e.g., 2024-12-01-preview)",
)
api_type: str | None = Field(
default_factory=lambda: os.getenv("AZURE_API_TYPE", "azure"),
description="Azure API type for Azure (e.g., azure)",
)
@classmethod
def sample_run_config(
cls,
api_key: str = "${env.AZURE_API_KEY:=}",
api_base: str = "${env.AZURE_API_BASE:=}",
api_version: str = "${env.AZURE_API_VERSION:=}",
api_type: str = "${env.AZURE_API_TYPE:=}",
**kwargs,
) -> dict[str, Any]:
return {
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
"api_type": api_type,
}

View file

@ -11,21 +11,17 @@ from botocore.client import BaseClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -47,12 +43,47 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
)
from .models import MODEL_ENTRIES
REGION_PREFIX_MAP = {
"us": "us.",
"eu": "eu.",
"ap": "ap.",
}
def _get_region_prefix(region: str | None) -> str:
# AWS requires region prefixes for inference profiles
if region is None:
return "us." # default to US when we don't know
# Handle case insensitive region matching
region_lower = region.lower()
for prefix in REGION_PREFIX_MAP:
if region_lower.startswith(f"{prefix}-"):
return REGION_PREFIX_MAP[prefix]
# Fallback to US for anything we don't recognize
return "us."
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
# Return ARNs unchanged
if model_id.startswith("arn:"):
return model_id
# Return inference profile IDs that already have regional prefixes
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
return model_id
# Default to US East when no region is provided
if region is None:
region = "us-east-1"
return _get_region_prefix(region) + model_id
class BedrockInferenceAdapter(
ModelRegistryHelper,
@ -61,7 +92,7 @@ class BedrockInferenceAdapter(
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self._config = config
self._client = None
@ -166,8 +197,13 @@ class BedrockInferenceAdapter(
options["repetition_penalty"] = sampling_params.repetition_penalty
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
# Convert foundation model ID to inference profile ID
region_name = self.client.meta.region_name
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
return {
"modelId": bedrock_model,
"modelId": inference_profile_id,
"body": json.dumps(
{
"prompt": prompt,
@ -176,31 +212,6 @@ class BedrockInferenceAdapter(
),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
for content in contents:
assert not content_has_media(content), "Bedrock does not support media for embeddings"
input_text = interleaved_content_as_str(content)
input_body = {"inputText": input_text}
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=model.provider_resource_id,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,

View file

@ -5,26 +5,23 @@
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from urllib.parse import urljoin
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@ -35,42 +32,41 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
from .models import MODEL_ENTRIES
class CerebrasInferenceAdapter(
OpenAIMixin,
ModelRegistryHelper,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_entries=MODEL_ENTRIES,
)
self.config = config
# TODO: make this use provider data, etc. like other providers
self.client = AsyncCerebras(
self._cerebras_client = AsyncCerebras(
base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
)
def get_api_key(self) -> str:
return self.config.api_key.get_secret_value()
def get_base_url(self) -> str:
return urljoin(self.config.base_url, "v1")
async def initialize(self) -> None:
return
@ -107,14 +103,14 @@ class CerebrasInferenceAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
r = await self._cerebras_client.completions.create(**params)
return process_completion_response(r)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_completion_stream_response(stream):
yield chunk
@ -156,14 +152,14 @@ class CerebrasInferenceAdapter(
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.completions.create(**params)
r = await self._cerebras_client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = await self.client.completions.create(**params)
stream = await self._cerebras_client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
@ -187,16 +183,6 @@ class CerebrasInferenceAdapter(
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,

View file

@ -20,8 +20,8 @@ class CerebrasImplConfig(BaseModel):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: SecretStr | None = Field(
default=os.environ.get("CEREBRAS_API_KEY"),
api_key: SecretStr = Field(
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
description="Cerebras API Key",
)

View file

@ -1,28 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry,
)
SAFETY_MODELS_ENTRIES = []
# https://inference-docs.cerebras.ai/models
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1-8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -5,10 +5,11 @@
# the root directory of this source tree.
from .config import DatabricksImplConfig
from .databricks import DatabricksInferenceAdapter
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
await impl.initialize()

View file

@ -6,7 +6,7 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
@ -17,16 +17,16 @@ class DatabricksImplConfig(BaseModel):
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: str = Field(
default=None,
api_token: SecretStr = Field(
default=SecretStr(None),
description="The Databricks API token",
)
@classmethod
def sample_run_config(
cls,
url: str = "${env.DATABRICKS_URL:=}",
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
url: str = "${env.DATABRICKS_HOST:=}",
api_token: str = "${env.DATABRICKS_TOKEN:=}",
**kwargs: Any,
) -> dict[str, Any]:
return {

View file

@ -4,74 +4,59 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from typing import Any
from openai import OpenAI
from databricks.sdk import WorkspaceClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
Model,
OpenAICompletion,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
build_hf_repo_model_entry,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import DatabricksImplConfig
SAFETY_MODELS_ENTRIES = []
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value,
),
] + SAFETY_MODELS_ENTRIES
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
ModelRegistryHelper,
OpenAIMixin,
Inference,
OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin,
):
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
@ -80,89 +65,80 @@ class DatabricksInferenceAdapter(
async def completion(
self,
model: str,
model_id: str,
content: InterleavedContent,
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
) -> AsyncGenerator:
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def chat_completion(
self,
model: str,
model_id: str,
messages: list[Message],
sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None,
tools: list[ToolDefinition] | None = None,
tool_choice: ToolChoice | None = ToolChoice.auto,
tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
request = ChatCompletionRequest(
model=model,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
if stream:
return self._stream_chat_completion(request, client)
else:
return await self._nonstream_chat_completion(request, client)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: OpenAI
) -> ChatCompletionResponse:
params = self._get_params(request)
r = client.completions.create(**params)
return process_chat_completion_response(r, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = self._get_params(request)
async def _to_async_generator():
s = client.completions.create(**params)
for chunk in s:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
return {
"model": request.model,
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in self.embedding_model_metadata:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = self.embedding_model_metadata[endpoint.name]
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue
self._model_cache[endpoint.name] = model
return list(self._model_cache.values())
async def should_refresh_models(self) -> bool:
return False

Some files were not shown because too many files have changed in this diff Show more