mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into agents-openai-migration
This commit is contained in:
commit
724322eeb2
673 changed files with 164269 additions and 14378 deletions
|
@ -27,6 +27,7 @@ from llama_stack.apis.inference import (
|
|||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
|
@ -481,7 +482,10 @@ class Agents(Protocol):
|
|||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
"""
|
||||
|
||||
@webmethod(route="/agents", method="POST", descriptive_name="create_agent")
|
||||
@webmethod(
|
||||
route="/agents", method="POST", descriptive_name="create_agent", deprecated=True, level=LLAMA_STACK_API_V1
|
||||
)
|
||||
@webmethod(route="/agents", method="POST", descriptive_name="create_agent", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
@ -494,7 +498,17 @@ class Agents(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn", method="POST", descriptive_name="create_agent_turn"
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
|
@ -524,6 +538,14 @@ class Agents(Protocol):
|
|||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
|
@ -549,6 +571,13 @@ class Agents(Protocol):
|
|||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
|
@ -568,6 +597,13 @@ class Agents(Protocol):
|
|||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_step(
|
||||
self,
|
||||
|
@ -586,7 +622,19 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session", method="POST", descriptive_name="create_agent_session")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
deprecated=True,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
@ -600,7 +648,8 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET")
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
|
@ -616,7 +665,10 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1
|
||||
)
|
||||
@webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
|
@ -629,7 +681,8 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE")
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
@ -640,7 +693,8 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET")
|
||||
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
"""List all agents.
|
||||
|
||||
|
@ -650,7 +704,8 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET")
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
|
@ -659,7 +714,8 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET")
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
|
@ -682,7 +738,7 @@ class Agents(Protocol):
|
|||
#
|
||||
# Both of these APIs are inherently stateful.
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="GET")
|
||||
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_openai_response(
|
||||
self,
|
||||
response_id: str,
|
||||
|
@ -694,7 +750,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="POST")
|
||||
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_openai_response(
|
||||
self,
|
||||
input: str | list[OpenAIResponseInput],
|
||||
|
@ -719,7 +775,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses", method="GET")
|
||||
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
@ -737,7 +793,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
|
||||
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_openai_response_input_items(
|
||||
self,
|
||||
response_id: str,
|
||||
|
@ -759,7 +815,7 @@ class Agents(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE")
|
||||
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
"""Delete an OpenAI response by its ID.
|
||||
|
||||
|
|
|
@ -276,13 +276,40 @@ class OpenAIResponseOutputMessageMCPListTools(BaseModel):
|
|||
tools: list[MCPListToolsTool]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalRequest(BaseModel):
|
||||
"""
|
||||
A request for human approval of a tool invocation.
|
||||
"""
|
||||
|
||||
arguments: str
|
||||
id: str
|
||||
name: str
|
||||
server_label: str
|
||||
type: Literal["mcp_approval_request"] = "mcp_approval_request"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OpenAIResponseMCPApprovalResponse(BaseModel):
|
||||
"""
|
||||
A response to an MCP approval request.
|
||||
"""
|
||||
|
||||
approval_request_id: str
|
||||
approve: bool
|
||||
type: Literal["mcp_approval_response"] = "mcp_approval_response"
|
||||
id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
OpenAIResponseOutput = Annotated[
|
||||
OpenAIResponseMessage
|
||||
| OpenAIResponseOutputMessageWebSearchToolCall
|
||||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseOutputMessageMCPCall
|
||||
| OpenAIResponseOutputMessageMCPListTools,
|
||||
| OpenAIResponseOutputMessageMCPListTools
|
||||
| OpenAIResponseMCPApprovalRequest,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
|
||||
|
@ -336,7 +363,6 @@ class OpenAIResponseObject(BaseModel):
|
|||
:param text: Text formatting configuration for the response
|
||||
:param top_p: (Optional) Nucleus sampling parameter used for generation
|
||||
:param truncation: (Optional) Truncation strategy applied to the response
|
||||
:param user: (Optional) User identifier associated with the request
|
||||
"""
|
||||
|
||||
created_at: int
|
||||
|
@ -354,7 +380,6 @@ class OpenAIResponseObject(BaseModel):
|
|||
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
top_p: float | None = None
|
||||
truncation: str | None = None
|
||||
user: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -725,6 +750,8 @@ OpenAIResponseInput = Annotated[
|
|||
| OpenAIResponseOutputMessageFileSearchToolCall
|
||||
| OpenAIResponseOutputMessageFunctionToolCall
|
||||
| OpenAIResponseInputFunctionToolCallOutput
|
||||
| OpenAIResponseMCPApprovalRequest
|
||||
| OpenAIResponseMCPApprovalResponse
|
||||
|
|
||||
# Fallback to the generic message type as a last resort
|
||||
OpenAIResponseMessage,
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import (
|
||||
InterleavedContent,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BatchInference(Protocol):
|
||||
"""Batch inference API for generating completions and chat completions.
|
||||
|
||||
This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
|
||||
|
||||
NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
|
||||
including (post-training, evals, etc).
|
||||
"""
|
||||
|
||||
@webmethod(route="/batch-inference/completion", method="POST")
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job:
|
||||
"""Generate completions for a batch of content.
|
||||
|
||||
:param model: The model to use for the completion.
|
||||
:param content_batch: The content to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param response_format: The response format to use for the completion.
|
||||
:param logprobs: The logprobs to use for the completion.
|
||||
:returns: A job for the completion.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/batch-inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
# zero-shot tool definitions as input to the model
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> Job:
|
||||
"""Generate chat completions for a batch of messages.
|
||||
|
||||
:param model: The model to use for the chat completion.
|
||||
:param messages_batch: The messages to complete.
|
||||
:param sampling_params: The sampling parameters to use for the completion.
|
||||
:param tools: The tools to use for the chat completion.
|
||||
:param tool_choice: The tool choice to use for the chat completion.
|
||||
:param tool_prompt_format: The tool prompt format to use for the chat completion.
|
||||
:param response_format: The response format to use for the chat completion.
|
||||
:param logprobs: The logprobs to use for the chat completion.
|
||||
:returns: A job for the chat completion.
|
||||
"""
|
||||
...
|
|
@ -8,6 +8,7 @@ from typing import Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
try:
|
||||
|
@ -42,7 +43,7 @@ class Batches(Protocol):
|
|||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST")
|
||||
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
|
@ -62,7 +63,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
|
||||
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
||||
|
@ -71,7 +72,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
|
||||
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
||||
|
@ -80,7 +81,7 @@ class Batches(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET")
|
||||
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
@ -53,7 +54,8 @@ class ListBenchmarksResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Benchmarks(Protocol):
|
||||
@webmethod(route="/eval/benchmarks", method="GET")
|
||||
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_benchmarks(self) -> ListBenchmarksResponse:
|
||||
"""List all benchmarks.
|
||||
|
||||
|
@ -61,7 +63,8 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
@ -73,7 +76,8 @@ class Benchmarks(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks", method="POST")
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def register_benchmark(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
@ -93,3 +97,12 @@ class Benchmarks(Protocol):
|
|||
:param metadata: The metadata to use for the benchmark.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
"""Unregister a benchmark.
|
||||
|
||||
:param benchmark_id: The ID of the benchmark to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
|
@ -20,7 +21,7 @@ class DatasetIO(Protocol):
|
|||
# keeping for aligning with inference/safety, but this is not used
|
||||
dataset_store: DatasetStore
|
||||
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
|
||||
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def iterrows(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
@ -44,7 +45,7 @@ class DatasetIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
|
||||
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
|
||||
"""Append rows to a dataset.
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Annotated, Any, Literal, Protocol
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
@ -145,7 +146,7 @@ class ListDatasetsResponse(BaseModel):
|
|||
|
||||
|
||||
class Datasets(Protocol):
|
||||
@webmethod(route="/datasets", method="POST")
|
||||
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_dataset(
|
||||
self,
|
||||
purpose: DatasetPurpose,
|
||||
|
@ -214,7 +215,7 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
@ -226,7 +227,7 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets", method="GET")
|
||||
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_datasets(self) -> ListDatasetsResponse:
|
||||
"""List all datasets.
|
||||
|
||||
|
@ -234,7 +235,7 @@ class Datasets(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
|
||||
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
@ -13,6 +13,7 @@ from llama_stack.apis.common.job_types import Job
|
|||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
@ -83,7 +84,8 @@ class EvaluateResponse(BaseModel):
|
|||
class Eval(Protocol):
|
||||
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
@ -97,7 +99,10 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def evaluate_rows(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
@ -115,7 +120,10 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||
"""Get the status of a job.
|
||||
|
||||
|
@ -125,7 +133,13 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||
"""Cancel a job.
|
||||
|
||||
|
@ -134,7 +148,15 @@ class Eval(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
deprecated=True,
|
||||
)
|
||||
@webmethod(
|
||||
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
|
||||
)
|
||||
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||
"""Get the result of a job.
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from fastapi import File, Form, Response, UploadFile
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -104,14 +105,12 @@ class OpenAIFileDeleteResponse(BaseModel):
|
|||
@trace_protocol
|
||||
class Files(Protocol):
|
||||
# OpenAI Files API Endpoints
|
||||
@webmethod(route="/openai/v1/files", method="POST")
|
||||
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
# TODO: expires_after is producing strange openapi spec, params are showing up as a required w/ oneOf being null
|
||||
expires_after: Annotated[ExpiresAfter | None, Form()] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
@ -119,15 +118,16 @@ class Files(Protocol):
|
|||
The file upload should be a multipart form request with:
|
||||
- file: The File object (not file name) to be uploaded.
|
||||
- purpose: The intended purpose of the uploaded file.
|
||||
- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).
|
||||
- expires_after: Optional form values describing expiration for the file.
|
||||
|
||||
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
|
||||
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
|
||||
:param expires_after: Optional form values describing expiration for the file.
|
||||
:returns: An OpenAIFileObject representing the uploaded file.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files", method="GET")
|
||||
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
@ -146,7 +146,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="GET")
|
||||
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file(
|
||||
self,
|
||||
file_id: str,
|
||||
|
@ -159,7 +159,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE")
|
||||
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def openai_delete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
|
@ -172,7 +172,7 @@ class Files(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET")
|
||||
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_file_content(
|
||||
self,
|
||||
file_id: str,
|
||||
|
|
|
@ -17,10 +17,11 @@ from typing import (
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.telemetry import MetricResponseMixin
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
StopReason,
|
||||
|
@ -913,6 +914,7 @@ class OpenAIEmbeddingData(BaseModel):
|
|||
"""
|
||||
|
||||
object: Literal["embedding"] = "embedding"
|
||||
# TODO: consider dropping str and using openai.types.embeddings.Embedding instead of OpenAIEmbeddingData
|
||||
embedding: list[float] | str
|
||||
index: int
|
||||
|
||||
|
@ -973,26 +975,6 @@ class EmbeddingTaskType(Enum):
|
|||
document = "document"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchCompletionResponse(BaseModel):
|
||||
"""Response from a batch completion request.
|
||||
|
||||
:param batch: List of completion responses, one for each input in the batch
|
||||
"""
|
||||
|
||||
batch: list[CompletionResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BatchChatCompletionResponse(BaseModel):
|
||||
"""Response from a batch chat completion request.
|
||||
|
||||
:param batch: List of chat completion responses, one for each conversation in the batch
|
||||
"""
|
||||
|
||||
batch: list[ChatCompletionResponse]
|
||||
|
||||
|
||||
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
|
||||
input_messages: list[OpenAIMessageParam]
|
||||
|
||||
|
@ -1026,7 +1008,6 @@ class InferenceProvider(Protocol):
|
|||
|
||||
model_store: ModelStore | None = None
|
||||
|
||||
@webmethod(route="/inference/completion", method="POST")
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -1049,28 +1030,6 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/batch-completion", method="POST", experimental=True)
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
"""Generate completions for a batch of content using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param content_batch: The content to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch completion is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/inference/chat-completion", method="POST")
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -1110,52 +1069,7 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
"""Generate chat completions for a batch of messages using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
|
||||
:param messages_batch: The messages to generate completions for.
|
||||
:param sampling_params: (Optional) Parameters to control the sampling strategy.
|
||||
:param tools: (Optional) List of tool definitions available to the model.
|
||||
:param tool_config: (Optional) Configuration for tool use.
|
||||
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
|
||||
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
|
||||
:returns: A BatchChatCompletionResponse with the full completions.
|
||||
"""
|
||||
raise NotImplementedError("Batch chat completion is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/inference/embeddings", method="POST")
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
"""Generate embeddings for content pieces using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
|
||||
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
|
||||
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
|
||||
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/rerank", method="POST", experimental=True)
|
||||
@webmethod(route="/inference/rerank", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -1174,7 +1088,7 @@ class InferenceProvider(Protocol):
|
|||
raise NotImplementedError("Reranking is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/openai/v1/completions", method="POST")
|
||||
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_completion(
|
||||
self,
|
||||
# Standard OpenAI completion parameters
|
||||
|
@ -1225,7 +1139,7 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="POST")
|
||||
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -1281,7 +1195,7 @@ class InferenceProvider(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/embeddings", method="POST")
|
||||
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -1310,7 +1224,7 @@ class Inference(InferenceProvider):
|
|||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions", method="GET")
|
||||
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_chat_completions(
|
||||
self,
|
||||
after: str | None = None,
|
||||
|
@ -1328,7 +1242,7 @@ class Inference(InferenceProvider):
|
|||
"""
|
||||
raise NotImplementedError("List chat completions is not implemented")
|
||||
|
||||
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
|
||||
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
|
||||
"""Describe a chat completion by its ID.
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -57,7 +58,7 @@ class ListRoutesResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Inspect(Protocol):
|
||||
@webmethod(route="/inspect/routes", method="GET")
|
||||
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
"""List all available API routes with their methods and implementing providers.
|
||||
|
||||
|
@ -65,7 +66,7 @@ class Inspect(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/health", method="GET")
|
||||
@webmethod(route="/health", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def health(self) -> HealthInfo:
|
||||
"""Get the current health status of the service.
|
||||
|
||||
|
@ -73,7 +74,7 @@ class Inspect(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/version", method="GET")
|
||||
@webmethod(route="/version", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def version(self) -> VersionInfo:
|
||||
"""Get the version of the service.
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -102,7 +103,7 @@ class OpenAIListModelsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models", method="GET")
|
||||
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
"""List all models.
|
||||
|
||||
|
@ -110,15 +111,7 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET")
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
:returns: A OpenAIListModelsResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="GET")
|
||||
@webmethod(route="/models/{model_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -130,7 +123,7 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models", method="POST")
|
||||
@webmethod(route="/models", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -150,7 +143,7 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE")
|
||||
@webmethod(route="/models/{model_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_model(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel, Field
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
from llama_stack.apis.common.training_types import Checkpoint
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
@ -283,7 +284,8 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
|||
|
||||
|
||||
class PostTraining(Protocol):
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
@ -310,7 +312,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
|
@ -332,7 +335,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/jobs", method="GET")
|
||||
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
|
||||
"""Get all training jobs.
|
||||
|
||||
|
@ -340,7 +344,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/status", method="GET")
|
||||
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
"""Get the status of a training job.
|
||||
|
||||
|
@ -349,7 +354,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
"""Cancel a training job.
|
||||
|
||||
|
@ -357,7 +363,8 @@ class PostTraining(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
"""Get the artifacts of a training job.
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -95,7 +96,7 @@ class ListPromptsResponse(BaseModel):
|
|||
class Prompts(Protocol):
|
||||
"""Protocol for prompt management operations."""
|
||||
|
||||
@webmethod(route="/prompts", method="GET")
|
||||
@webmethod(route="/prompts", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts.
|
||||
|
||||
|
@ -103,7 +104,7 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/versions", method="GET")
|
||||
@webmethod(route="/prompts/{prompt_id}/versions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_prompt_versions(
|
||||
self,
|
||||
prompt_id: str,
|
||||
|
@ -115,7 +116,7 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="GET")
|
||||
@webmethod(route="/prompts/{prompt_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
|
@ -129,7 +130,7 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts", method="POST")
|
||||
@webmethod(route="/prompts", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@ -143,7 +144,7 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="PUT")
|
||||
@webmethod(route="/prompts/{prompt_id}", method="PUT", level=LLAMA_STACK_API_V1)
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
|
@ -163,7 +164,7 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}", method="DELETE")
|
||||
@webmethod(route="/prompts/{prompt_id}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def delete_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
|
@ -174,7 +175,7 @@ class Prompts(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT")
|
||||
@webmethod(route="/prompts/{prompt_id}/set-default-version", method="PUT", level=LLAMA_STACK_API_V1)
|
||||
async def set_default_version(
|
||||
self,
|
||||
prompt_id: str,
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.datatypes import HealthResponse
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -45,7 +46,7 @@ class Providers(Protocol):
|
|||
Providers API for inspecting, listing, and modifying providers and their configurations.
|
||||
"""
|
||||
|
||||
@webmethod(route="/providers", method="GET")
|
||||
@webmethod(route="/providers", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_providers(self) -> ListProvidersResponse:
|
||||
"""List all available providers.
|
||||
|
||||
|
@ -53,7 +54,7 @@ class Providers(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/providers/{provider_id}", method="GET")
|
||||
@webmethod(route="/providers/{provider_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def inspect_provider(self, provider_id: str) -> ProviderInfo:
|
||||
"""Get detailed information about a specific provider.
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -97,7 +98,7 @@ class ShieldStore(Protocol):
|
|||
class Safety(Protocol):
|
||||
shield_store: ShieldStore
|
||||
|
||||
@webmethod(route="/safety/run-shield", method="POST")
|
||||
@webmethod(route="/safety/run-shield", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
@ -113,7 +114,7 @@ class Safety(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/moderations", method="POST")
|
||||
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||
"""Classifies if text and/or image inputs are potentially harmful.
|
||||
:param input: Input (or inputs) to classify.
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
# mapping of metric to value
|
||||
|
@ -61,7 +62,7 @@ class ScoringFunctionStore(Protocol):
|
|||
class Scoring(Protocol):
|
||||
scoring_function_store: ScoringFunctionStore
|
||||
|
||||
@webmethod(route="/scoring/score-batch", method="POST")
|
||||
@webmethod(route="/scoring/score-batch", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
@ -77,7 +78,7 @@ class Scoring(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring/score", method="POST")
|
||||
@webmethod(route="/scoring/score", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def score(
|
||||
self,
|
||||
input_rows: list[dict[str, Any]],
|
||||
|
|
|
@ -18,6 +18,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
|
@ -160,7 +161,7 @@ class ListScoringFunctionsResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class ScoringFunctions(Protocol):
|
||||
@webmethod(route="/scoring-functions", method="GET")
|
||||
@webmethod(route="/scoring-functions", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
|
||||
"""List all scoring functions.
|
||||
|
||||
|
@ -168,7 +169,7 @@ class ScoringFunctions(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
|
||||
"""Get a scoring function by its ID.
|
||||
|
||||
|
@ -177,7 +178,7 @@ class ScoringFunctions(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions", method="POST")
|
||||
@webmethod(route="/scoring-functions", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_scoring_function(
|
||||
self,
|
||||
scoring_fn_id: str,
|
||||
|
@ -197,3 +198,11 @@ class ScoringFunctions(Protocol):
|
|||
:param params: The parameters for the scoring function for benchmark eval, these can be overridden for app eval.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
|
||||
"""Unregister a scoring function.
|
||||
|
||||
:param scoring_fn_id: The ID of the scoring function to unregister.
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -49,7 +50,7 @@ class ListShieldsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields", method="GET")
|
||||
@webmethod(route="/shields", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_shields(self) -> ListShieldsResponse:
|
||||
"""List all shields.
|
||||
|
||||
|
@ -57,7 +58,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET")
|
||||
@webmethod(route="/shields/{identifier:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_shield(self, identifier: str) -> Shield:
|
||||
"""Get a shield by its identifier.
|
||||
|
||||
|
@ -66,7 +67,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields", method="POST")
|
||||
@webmethod(route="/shields", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_shield(
|
||||
self,
|
||||
shield_id: str,
|
||||
|
@ -84,7 +85,7 @@ class Shields(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE")
|
||||
@webmethod(route="/shields/{identifier:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
"""Unregister a shield.
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any, Protocol
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
|
@ -59,7 +60,7 @@ class SyntheticDataGenerationResponse(BaseModel):
|
|||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic-data-generation/generate")
|
||||
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
|
||||
def synthetic_data_generate(
|
||||
self,
|
||||
dialogs: list[Message],
|
||||
|
|
|
@ -16,6 +16,7 @@ from typing import (
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
@ -412,7 +413,7 @@ class QueryMetricsResponse(BaseModel):
|
|||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/events", method="POST")
|
||||
@webmethod(route="/telemetry/events", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def log_event(
|
||||
self,
|
||||
event: Event,
|
||||
|
@ -425,7 +426,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1)
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
|
@ -443,7 +444,9 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1
|
||||
)
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
||||
|
@ -453,7 +456,10 @@ class Telemetry(Protocol):
|
|||
...
|
||||
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}",
|
||||
method="GET",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
@ -464,7 +470,12 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/spans/{span_id:path}/tree",
|
||||
method="POST",
|
||||
required_scope=REQUIRED_SCOPE,
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
@ -480,7 +491,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1)
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
@ -496,7 +507,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/export", method="POST")
|
||||
@webmethod(route="/telemetry/spans/export", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def save_spans_to_dataset(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
@ -513,7 +524,9 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
@webmethod(
|
||||
route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE, level=LLAMA_STACK_API_V1
|
||||
)
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
|
|
|
@ -11,6 +11,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||
from typing_extensions import runtime_checkable
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
@ -185,7 +186,7 @@ class RAGQueryConfig(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class RAGToolRuntime(Protocol):
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST")
|
||||
@webmethod(route="/tool-runtime/rag-tool/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert(
|
||||
self,
|
||||
documents: list[RAGDocument],
|
||||
|
@ -200,7 +201,7 @@ class RAGToolRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/rag-tool/query", method="POST")
|
||||
@webmethod(route="/tool-runtime/rag-tool/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query(
|
||||
self,
|
||||
content: InterleavedContent,
|
||||
|
|
|
@ -12,6 +12,7 @@ from typing_extensions import runtime_checkable
|
|||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -26,6 +27,8 @@ class ToolParameter(BaseModel):
|
|||
:param parameter_type: Type of the parameter (e.g., string, integer)
|
||||
:param description: Human-readable description of what the parameter does
|
||||
:param required: Whether this parameter is required for tool invocation
|
||||
:param items: Type of the elements when parameter_type is array
|
||||
:param title: (Optional) Title of the parameter
|
||||
:param default: (Optional) Default value for the parameter if not provided
|
||||
"""
|
||||
|
||||
|
@ -33,6 +36,8 @@ class ToolParameter(BaseModel):
|
|||
parameter_type: str
|
||||
description: str
|
||||
required: bool = Field(default=True)
|
||||
items: dict | None = None
|
||||
title: str | None = None
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
|
@ -151,7 +156,7 @@ class ListToolDefsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class ToolGroups(Protocol):
|
||||
@webmethod(route="/toolgroups", method="POST")
|
||||
@webmethod(route="/toolgroups", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
@ -168,7 +173,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_tool_group(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
@ -180,7 +185,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups", method="GET")
|
||||
@webmethod(route="/toolgroups", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tool_groups(self) -> ListToolGroupsResponse:
|
||||
"""List tool groups with optional provider.
|
||||
|
||||
|
@ -188,7 +193,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools", method="GET")
|
||||
@webmethod(route="/tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
|
||||
"""List tools with optional tool group.
|
||||
|
||||
|
@ -197,7 +202,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET")
|
||||
@webmethod(route="/tools/{tool_name:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
|
@ -209,7 +214,7 @@ class ToolGroups(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
|
||||
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_toolgroup(
|
||||
self,
|
||||
toolgroup_id: str,
|
||||
|
@ -238,7 +243,7 @@ class ToolRuntime(Protocol):
|
|||
rag_tool: RAGToolRuntime | None = None
|
||||
|
||||
# TODO: This needs to be renamed once OPEN API generator name conflict issue is fixed.
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET")
|
||||
@webmethod(route="/tool-runtime/list-tools", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
|
@ -250,7 +255,7 @@ class ToolRuntime(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST")
|
||||
@webmethod(route="/tool-runtime/invoke", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
|
||||
"""Run a tool with the given arguments.
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from typing import Literal, Protocol, runtime_checkable
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -65,7 +66,7 @@ class ListVectorDBsResponse(BaseModel):
|
|||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class VectorDBs(Protocol):
|
||||
@webmethod(route="/vector-dbs", method="GET")
|
||||
@webmethod(route="/vector-dbs", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_vector_dbs(self) -> ListVectorDBsResponse:
|
||||
"""List all vector databases.
|
||||
|
||||
|
@ -73,7 +74,7 @@ class VectorDBs(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def get_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
@ -85,7 +86,7 @@ class VectorDBs(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs", method="POST")
|
||||
@webmethod(route="/vector-dbs", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
@ -107,7 +108,7 @@ class VectorDBs(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
|
||||
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE", level=LLAMA_STACK_API_V1)
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
"""Unregister a vector database.
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
@ -317,7 +318,8 @@ class VectorStoreChunkingStrategyStatic(BaseModel):
|
|||
|
||||
|
||||
VectorStoreChunkingStrategy = Annotated[
|
||||
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic, Field(discriminator="type")
|
||||
VectorStoreChunkingStrategyAuto | VectorStoreChunkingStrategyStatic,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy")
|
||||
|
||||
|
@ -426,6 +428,44 @@ class VectorStoreFileDeleteResponse(BaseModel):
|
|||
deleted: bool = True
|
||||
|
||||
|
||||
@json_schema_type
class VectorStoreFileBatchObject(BaseModel):
    """A batch of files attached to a vector store, in the OpenAI-compatible shape.

    :param id: Unique identifier of the file batch
    :param object: Object type identifier; defaults to "vector_store.file_batch"
    :param created_at: Timestamp recording when the file batch was created
    :param vector_store_id: ID of the vector store the file batch belongs to
    :param status: Current processing status of the file batch
    :param file_counts: Counts of the batch's files grouped by processing status
    """

    id: str
    object: str = "vector_store.file_batch"
    created_at: int
    vector_store_id: str
    status: VectorStoreFileStatus
    file_counts: VectorStoreFileCounts
|
||||
|
||||
|
||||
@json_schema_type
class VectorStoreFilesListInBatchResponse(BaseModel):
    """One page of vector store files belonging to a file batch.

    :param object: Object type identifier; defaults to "list"
    :param data: The vector store file objects on this page
    :param first_id: (Optional) ID of the first file on this page, for pagination
    :param last_id: (Optional) ID of the last file on this page, for pagination
    :param has_more: Whether further files exist beyond this page
    """

    object: str = "list"
    data: list[VectorStoreFileObject]
    first_id: str | None = None
    last_id: str | None = None
    has_more: bool = False
|
||||
|
||||
|
||||
class VectorDBStore(Protocol):
|
||||
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
|
||||
|
||||
|
@ -437,7 +477,7 @@ class VectorIO(Protocol):
|
|||
|
||||
# this will just block now until chunks are inserted, but it should
|
||||
# probably return a Job instance which can be polled for completion
|
||||
@webmethod(route="/vector-io/insert", method="POST")
|
||||
@webmethod(route="/vector-io/insert", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def insert_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
@ -455,7 +495,7 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector-io/query", method="POST")
|
||||
@webmethod(route="/vector-io/query", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def query_chunks(
|
||||
self,
|
||||
vector_db_id: str,
|
||||
|
@ -472,7 +512,7 @@ class VectorIO(Protocol):
|
|||
...
|
||||
|
||||
# OpenAI Vector Stores API endpoints
|
||||
@webmethod(route="/openai/v1/vector_stores", method="POST")
|
||||
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str | None = None,
|
||||
|
@ -498,7 +538,7 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores", method="GET")
|
||||
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int | None = 20,
|
||||
|
@ -516,7 +556,7 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="GET")
|
||||
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -528,7 +568,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="POST")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -546,7 +590,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -558,7 +606,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/search", method="POST")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/search",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -567,7 +619,9 @@ class VectorIO(Protocol):
|
|||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
|
||||
search_mode: (
|
||||
str | None
|
||||
) = "vector", # Using str instead of Literal due to OpenAPI schema generator limitations
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
"""Search for chunks in a vector store.
|
||||
|
||||
|
@ -584,7 +638,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="POST")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -602,7 +660,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files", method="GET")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -624,7 +686,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="GET")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -638,7 +704,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content", method="GET")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -652,7 +722,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="POST")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="POST",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -668,7 +742,11 @@ class VectorIO(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}", method="DELETE")
|
||||
@webmethod(
|
||||
route="/vector_stores/{vector_store_id}/files/{file_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1,
|
||||
)
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
|
@ -681,3 +759,89 @@ class VectorIO(Protocol):
|
|||
:returns: A VectorStoreFileDeleteResponse indicating the deletion status.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/vector_stores/{vector_store_id}/file_batches", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store_file_batch(
    self,
    vector_store_id: str,
    file_ids: list[str],
    attributes: dict[str, Any] | None = None,
    chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
    """Create a file batch in a vector store.

    :param vector_store_id: ID of the vector store the file batch is created for.
    :param file_ids: File IDs that the vector store should use.
    :param attributes: (Optional) Key-value attributes stored with the files.
    :param chunking_strategy: (Optional) Strategy used to chunk the file(s). Defaults to auto.
    :returns: A VectorStoreFileBatchObject describing the newly created file batch.
    """
    ...
|
||||
|
||||
@webmethod(route="/vector_stores/{vector_store_id}/file_batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_vector_store_file_batch(
    self,
    batch_id: str,
    vector_store_id: str,
) -> VectorStoreFileBatchObject:
    """Fetch a single file batch of a vector store.

    :param batch_id: ID of the file batch to fetch.
    :param vector_store_id: ID of the vector store that contains the file batch.
    :returns: A VectorStoreFileBatchObject describing the file batch.
    """
    ...
|
||||
|
||||
@webmethod(
    route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files", method="GET", level=LLAMA_STACK_API_V1
)
async def openai_list_files_in_vector_store_file_batch(
    self,
    batch_id: str,
    vector_store_id: str,
    after: str | None = None,
    before: str | None = None,
    filter: str | None = None,
    limit: int | None = 20,
    order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
    """List the vector store files that belong to a file batch.

    :param batch_id: ID of the file batch whose files are listed.
    :param vector_store_id: ID of the vector store that contains the file batch.
    :param after: Pagination cursor. `after` is an object ID that defines your place in the list.
    :param before: Pagination cursor. `before` is an object ID that defines your place in the list.
    :param filter: Restrict results by file status. One of in_progress, completed, failed, cancelled.
    :param limit: Maximum number of objects to return. May range between 1 and 100; the default is 20.
    :param order: Sort order by the `created_at` timestamp of the objects: `asc` for ascending, `desc` for descending.
    :returns: A VectorStoreFilesListInBatchResponse holding the files in the batch.
    """
    ...
|
||||
|
||||
@webmethod(
    route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1
)
async def openai_cancel_vector_store_file_batch(
    self,
    batch_id: str,
    vector_store_id: str,
) -> VectorStoreFileBatchObject:
    """Cancel a file batch of a vector store.

    :param batch_id: ID of the file batch to cancel.
    :param vector_store_id: ID of the vector store that contains the file batch.
    :returns: A VectorStoreFileBatchObject describing the cancelled file batch.
    """
    ...
|
||||
|
|
|
@ -4,4 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
LLAMA_STACK_API_VERSION = "v1"
|
||||
LLAMA_STACK_API_V1 = "v1"
|
||||
LLAMA_STACK_API_V1BETA = "v1beta"
|
||||
LLAMA_STACK_API_V1ALPHA = "v1alpha"
|
||||
|
|
|
@ -48,15 +48,12 @@ def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
|
|||
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
|
||||
|
||||
|
||||
def calculate_sha256(filepath: Path, chunk_size: int = 8192) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *filepath*.

    The file is opened in binary mode and consumed ``chunk_size`` bytes at a
    time, so the whole file is never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(filepath, "rb") as stream:
        # An empty bytes object marks end-of-file and ends the loop.
        while block := stream.read(chunk_size):
            digest.update(block)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def load_checksums(checklist_path: Path) -> dict[str, str]:
    """Parse a checksum list file into a ``{relative_path: sha256_hex}`` mapping.

    Each non-blank line is expected to hold a checksum followed by whitespace
    and a file path. A single leading ``./`` on the path is dropped.

    :param checklist_path: Path to the checksum list file.
    :returns: Mapping from file path (as written in the list, minus a leading
        ``./``) to its expected checksum.
    :raises ValueError: If a non-blank line does not contain both fields.
    """
    # NOTE(review): this initialisation sits in a gap of the visible diff and is
    # reconstructed from how `checksums` is used below - confirm against the file.
    checksums: dict[str, str] = {}
    with open(checklist_path) as f:
        for line in f:
            entry = line.strip()
            if not entry:
                continue
            # Split on any run of whitespace so both "hash path" and the
            # two-space "hash  path" layout yield a clean path.
            sha256sum, filepath = entry.split(maxsplit=1)
            # removeprefix drops exactly one leading "./"; lstrip("./") would
            # strip every leading "." and "/" and mangle ".hidden/x" or "../x".
            filepath = filepath.removeprefix("./")
            checksums[filepath] = sha256sum
    return checksums
|
||||
|
||||
|
||||
|
@ -88,7 +85,7 @@ def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) -
|
|||
matches = False
|
||||
|
||||
if exists:
|
||||
actual_hash = calculate_md5(full_path)
|
||||
actual_hash = calculate_sha256(full_path)
|
||||
matches = actual_hash == expected_hash
|
||||
|
||||
results.append(
|
||||
|
|
|
@ -147,7 +147,7 @@ WORKDIR /app
|
|||
|
||||
RUN dnf -y update && dnf install -y iputils git net-tools wget \
|
||||
vim-minimal python3.12 python3.12-pip python3.12-wheel \
|
||||
python3.12-setuptools python3.12-devel gcc make && \
|
||||
python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
|
||||
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
|
|||
procps psmisc lsof \
|
||||
traceroute \
|
||||
bubblewrap \
|
||||
gcc \
|
||||
gcc g++ \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
|
|
@ -15,7 +15,6 @@ import httpx
|
|||
from pydantic import BaseModel, parse_obj_as
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
|
||||
_CLIENT_CLASSES = {}
|
||||
|
@ -114,7 +113,24 @@ def create_api_client_class(protocol) -> type:
|
|||
break
|
||||
kwargs[param.name] = args[i]
|
||||
|
||||
url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
|
||||
if not webmethods:
|
||||
raise RuntimeError(f"Method {method} has no webmethod decorators")
|
||||
|
||||
# Choose the preferred webmethod (non-deprecated if available)
|
||||
preferred_webmethod = None
|
||||
for wm in webmethods:
|
||||
if not getattr(wm, "deprecated", False):
|
||||
preferred_webmethod = wm
|
||||
break
|
||||
|
||||
# If no non-deprecated found, use the first one
|
||||
if preferred_webmethod is None:
|
||||
preferred_webmethod = webmethods[0]
|
||||
|
||||
url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"
|
||||
|
||||
def convert(value):
|
||||
if isinstance(value, list):
|
||||
|
|
|
@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
|
@ -437,6 +433,12 @@ class InferenceStoreConfig(BaseModel):
|
|||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class ResponsesStoreConfig(BaseModel):
    """Configuration of the responses store: SQL backend plus write-queue sizing."""

    sql_store_config: SqlStoreConfig
    max_write_queue_size: int = Field(
        default=10000,
        description="Max queued writes for responses store",
    )
    num_writers: int = Field(
        default=4,
        description="Number of concurrent background writers",
    )
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
|
|
|
@ -16,16 +16,18 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
|||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
|
||||
|
||||
|
||||
def stack_apis() -> list[Api]:
    """Return every member of the ``Api`` enum, in definition order."""
    return [*Api]
|
||||
|
||||
|
@ -70,31 +72,16 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|||
|
||||
def providable_apis() -> list[Api]:
    """Return the APIs that are neither routing-table APIs nor internal APIs.

    Routing-table APIs are taken from ``builtin_automatically_routed_apis()``;
    internal ones from ``INTERNAL_APIS``. Enum definition order is preserved.
    """
    routing_table_apis = {info.routing_table_api for info in builtin_automatically_routed_apis()}
    selected: list[Api] = []
    for candidate in Api:
        if candidate in routing_table_apis or candidate in INTERNAL_APIS:
            continue
        selected.append(candidate)
    return selected
|
||||
|
||||
|
||||
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
    """Build a ``RemoteProviderSpec`` for *api* from a parsed spec mapping.

    The provider type is ``remote::<adapter_type>``, where ``adapter_type`` is
    read from *spec_data*; every entry of *spec_data* is also forwarded to the
    spec constructor as a keyword argument.
    """
    provider_type = f"remote::{spec_data['adapter_type']}"
    return RemoteProviderSpec(api=api, provider_type=provider_type, **spec_data)
|
||||
|
||||
|
||||
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
    """Build an ``InlineProviderSpec`` for *api* from a parsed spec mapping.

    The provider type is ``inline::<provider_name>``; every entry of
    *spec_data* is forwarded to the spec constructor as a keyword argument.
    """
    provider_type = f"inline::{provider_name}"
    return InlineProviderSpec(api=api, provider_type=provider_type, **spec_data)
|
||||
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
|
|||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import (
|
||||
construct_stack,
|
||||
Stack,
|
||||
get_stack_run_config_from_distro,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
try:
|
||||
self.route_impls = None
|
||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||
|
||||
stack = Stack(self.config, self.custom_provider_registry)
|
||||
await stack.initialize()
|
||||
self.impls = stack.impls
|
||||
except ModuleNotFoundError as _e:
|
||||
cprint(_e.msg, color="red", file=sys.stderr)
|
||||
cprint(
|
||||
|
@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
raise _e
|
||||
|
||||
assert self.impls is not None
|
||||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ from llama_stack.apis.telemetry import Telemetry
|
|||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_dbs import VectorDBs
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.client import get_client_impl
|
||||
from llama_stack.core.datatypes import (
|
||||
AccessRule,
|
||||
|
@ -412,8 +413,14 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
|
||||
mro = type(obj).__mro__
|
||||
for name, value in inspect.getmembers(protocol):
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
|
||||
if value.__webmethod__.experimental:
|
||||
if inspect.isfunction(value) and hasattr(value, "__webmethods__"):
|
||||
has_alpha_api = False
|
||||
for webmethod in value.__webmethods__:
|
||||
if webmethod.level == LLAMA_STACK_API_V1ALPHA:
|
||||
has_alpha_api = True
|
||||
break
|
||||
# if this API has multiple webmethods, and one of them is an alpha API, this API should be skipped when checking for missing or not callable routes
|
||||
if has_alpha_api:
|
||||
continue
|
||||
if not hasattr(obj, name):
|
||||
missing_methods.append((name, "missing"))
|
||||
|
|
|
@ -16,20 +16,15 @@ from pydantic import Field, TypeAdapter
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
|
@ -50,7 +45,6 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -273,30 +267,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
return response
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_chat_completion(
|
||||
model_id=model_id,
|
||||
messages_batch=messages_batch,
|
||||
tools=tools,
|
||||
tool_config=tool_config,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -338,39 +308,6 @@ class InferenceRouter(Inference):
|
|||
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
await self._get_model(model_id, ModelType.embedding)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -8,9 +8,7 @@ import asyncio
|
|||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
|
@ -19,9 +17,11 @@ from llama_stack.apis.vector_io import (
|
|||
VectorIO,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileBatchObject,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFilesListInBatchResponse,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreListResponse,
|
||||
VectorStoreObject,
|
||||
|
@ -193,7 +193,10 @@ class VectorIORouter(VectorIO):
|
|||
all_stores = all_stores[after_index + 1 :]
|
||||
|
||||
if before:
|
||||
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
|
||||
before_index = next(
|
||||
(i for i, store in enumerate(all_stores) if store.id == before),
|
||||
len(all_stores),
|
||||
)
|
||||
all_stores = all_stores[:before_index]
|
||||
|
||||
# Apply limit
|
||||
|
@ -363,3 +366,61 @@ class VectorIORouter(VectorIO):
|
|||
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
|
||||
)
|
||||
return health_statuses
|
||||
|
||||
async def openai_create_vector_store_file_batch(
    self,
    vector_store_id: str,
    file_ids: list[str],
    attributes: dict[str, Any] | None = None,
    chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
    """Log the request at debug level, then delegate to the routing table."""
    logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
    # All arguments are forwarded by keyword, unchanged.
    forwarded = {
        "vector_store_id": vector_store_id,
        "file_ids": file_ids,
        "attributes": attributes,
        "chunking_strategy": chunking_strategy,
    }
    return await self.routing_table.openai_create_vector_store_file_batch(**forwarded)
|
||||
|
||||
async def openai_retrieve_vector_store_file_batch(
    self,
    batch_id: str,
    vector_store_id: str,
) -> VectorStoreFileBatchObject:
    """Log the request at debug level, then delegate to the routing table."""
    logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
    # Both identifiers are forwarded by keyword, unchanged.
    forwarded = {"batch_id": batch_id, "vector_store_id": vector_store_id}
    return await self.routing_table.openai_retrieve_vector_store_file_batch(**forwarded)
|
||||
|
||||
async def openai_list_files_in_vector_store_file_batch(
    self,
    batch_id: str,
    vector_store_id: str,
    after: str | None = None,
    before: str | None = None,
    filter: str | None = None,
    limit: int | None = 20,
    order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
    """Log the request at debug level, then delegate to the routing table."""
    logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
    # Identifiers and all paging/filter options are forwarded by keyword, unchanged.
    forwarded = {
        "batch_id": batch_id,
        "vector_store_id": vector_store_id,
        "after": after,
        "before": before,
        "filter": filter,
        "limit": limit,
        "order": order,
    }
    return await self.routing_table.openai_list_files_in_vector_store_file_batch(**forwarded)
|
||||
|
||||
async def openai_cancel_vector_store_file_batch(
    self,
    batch_id: str,
    vector_store_id: str,
) -> VectorStoreFileBatchObject:
    """Log the request at debug level, then delegate to the routing table."""
    logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
    # Both identifiers are forwarded by keyword, unchanged.
    forwarded = {"batch_id": batch_id, "vector_store_id": vector_store_id}
    return await self.routing_table.openai_cancel_vector_store_file_batch(**forwarded)
|
||||
|
|
|
@ -56,3 +56,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
provider_resource_id=provider_benchmark_id,
|
||||
)
|
||||
await self.register_object(benchmark)
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
    """Look up the benchmark via ``get_benchmark`` and pass it to ``unregister_object``."""
    await self.unregister_object(await self.get_benchmark(benchmark_id))
|
||||
|
|
|
@ -64,6 +64,10 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
return await p.unregister_shield(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
elif api == Api.eval:
|
||||
return await p.unregister_benchmark(obj.identifier)
|
||||
elif api == Api.scoring:
|
||||
return await p.unregister_scoring_function(obj.identifier)
|
||||
elif api == Api.tool_runtime:
|
||||
return await p.unregister_toolgroup(obj.identifier)
|
||||
else:
|
||||
|
|
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
|
|
|
@ -60,3 +60,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
)
|
||||
scoring_fn.provider_id = provider_id
|
||||
await self.register_object(scoring_fn)
|
||||
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
    """Look up the scoring function via ``get_scoring_function`` and pass it to ``unregister_object``."""
    await self.unregister_object(await self.get_scoring_function(scoring_fn_id))
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import ToolGroupWithOwner
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
@ -54,7 +54,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
all_tools = []
|
||||
for toolgroup in toolgroups:
|
||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||
await self._index_tools(toolgroup)
|
||||
try:
|
||||
await self._index_tools(toolgroup)
|
||||
except AuthenticationRequiredError:
|
||||
# Send authentication errors back to the client so it knows
|
||||
# that it needs to supply credentials for remote MCP servers.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Other errors that the client cannot fix are logged and
|
||||
# those specific toolgroups are skipped.
|
||||
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
|
||||
logger.debug(e, exc_info=True)
|
||||
continue
|
||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
|
|
|
@ -14,7 +14,6 @@ from starlette.routing import Route
|
|||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
|
@ -54,22 +53,23 @@ def get_all_api_routes(
|
|||
protocol_methods.append((f"{tool_group.value}.{name}", method))
|
||||
|
||||
for name, method in protocol_methods:
|
||||
if not hasattr(method, "__webmethod__"):
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
if not webmethods:
|
||||
continue
|
||||
|
||||
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
|
||||
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
|
||||
webmethod = method.__webmethod__ # type: ignore[attr-defined]
|
||||
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == hdrs.METH_GET:
|
||||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
# Create routes for each webmethod decorator
|
||||
for webmethod in webmethods:
|
||||
path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == hdrs.METH_GET:
|
||||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
|
||||
apis[api] = routes
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
|
@ -24,7 +25,6 @@ from typing import Annotated, Any, get_origin
|
|||
import httpx
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from aiohttp import hdrs
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
@ -44,23 +44,17 @@ from llama_stack.core.datatypes import (
|
|||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
shutdown_stack,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
|
@ -74,13 +68,12 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
|||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
from .tracing import TracingMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
@ -156,21 +149,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
)
|
||||
|
||||
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
class StackApp(FastAPI):
|
||||
"""
|
||||
await shutdown_stack(app.__llama_stack_impls__)
|
||||
A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
|
||||
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, config: StackRunConfig, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.stack: Stack = Stack(config)
|
||||
|
||||
# This code is called from a running event loop managed by uvicorn so we cannot simply call
|
||||
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
|
||||
# function.
|
||||
# As a workaround, we use a thread pool executor to run the initialize() method
|
||||
# in a separate thread.
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, self.stack.initialize())
|
||||
future.result()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(app: StackApp):
|
||||
logger.info("Starting up")
|
||||
assert app.stack is not None
|
||||
app.stack.create_registry_refresh_task()
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
await shutdown(app)
|
||||
await app.stack.shutdown()
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
@ -287,65 +293,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
return route_handler
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
class ClientVersionMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
@ -386,73 +333,61 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
Note: version checking is not an argument; it is disabled by setting the LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
# Load and process configuration
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
||||
app = FastAPI(
|
||||
app = StackApp(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
config=config,
|
||||
)
|
||||
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
impls = app.stack.impls
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
|
@ -553,9 +488,54 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
|
||||
try:
|
||||
app = create_app(
|
||||
config_file=config_or_distro,
|
||||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
|
@ -593,7 +573,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# Run uvicorn in the existing event loop to preserve background tasks
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
|
@ -604,13 +583,9 @@ def main(args: argparse.Namespace | None = None):
|
|||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
finally:
|
||||
if not loop.is_closed():
|
||||
logger.debug("Closing event loop")
|
||||
loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
|
|
80
llama_stack/core/server/tracing.py
Normal file
80
llama_stack/core/server/tracing.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
|
||||
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Log deprecation warning if route is deprecated
|
||||
if getattr(webmethod, "deprecated", False):
|
||||
logger.warning(
|
||||
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
|
||||
f"This route is deprecated and may be removed in a future version. "
|
||||
f"Please check the docs for the supported version."
|
||||
)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
|
@ -14,7 +14,6 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -54,7 +53,6 @@ class LlamaStack(
|
|||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
Agents,
|
||||
Safety,
|
||||
SyntheticDataGeneration,
|
||||
|
@ -315,78 +313,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
self.provider_registry = provider_registry
|
||||
self.impls = None
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def initialize(self):
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, self.run_config)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
self.impls = impls
|
||||
|
||||
def create_registry_refresh_task(self):
|
||||
assert self.impls is not None, "Must call initialize() before starting"
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
logger.error("Model refresh task cancelled")
|
||||
elif task.exception():
|
||||
logger.error(f"Model refresh task failed: {task.exception()}")
|
||||
traceback.print_exception(task.exception())
|
||||
else:
|
||||
logger.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
|
||||
async def shutdown(self):
|
||||
for impl in self.impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning(f"No shutdown method for {impl_name}")
|
||||
except TimeoutError:
|
||||
logger.exception(f"Shutdown timeout for {impl_name}")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
logger.error("Model refresh task cancelled")
|
||||
elif task.exception():
|
||||
logger.error(f"Model refresh task failed: {task.exception()}")
|
||||
traceback.print_exception(task.exception())
|
||||
else:
|
||||
logger.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
return impls
|
||||
|
||||
|
||||
async def shutdown_stack(impls: dict[Api, Any]):
|
||||
for impl in impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning(f"No shutdown method for {impl_name}")
|
||||
except TimeoutError:
|
||||
logger.exception(f"Shutdown timeout for {impl_name}")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
REGISTRY_REFRESH_TASK.cancel()
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
REGISTRY_REFRESH_TASK.cancel()
|
||||
|
||||
|
||||
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||
|
|
|
@ -123,6 +123,6 @@ if [[ "$env_type" == "venv" ]]; then
|
|||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||
echo -e "Please refer to the documentation for more information: https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html#llama-stack-build"
|
||||
echo -e "Please refer to the documentation for more information: https://llamastack.github.io/latest/distributions/building_distro.html#llama-stack-build"
|
||||
exit 1
|
||||
fi
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
## Developer Setup
|
||||
|
||||
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
|
||||
1. Start up Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.html).
|
||||
|
||||
```
|
||||
llama stack build --distro together --image-type venv
|
||||
|
|
|
@ -17,6 +17,7 @@ distribution_spec:
|
|||
- provider_type: remote::vertexai
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: remote::azure
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
|
|
|
@ -81,6 +81,13 @@ providers:
|
|||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
|
|
|
@ -23,6 +23,8 @@ distribution_spec:
|
|||
- provider_type: inline::basic
|
||||
tool_runtime:
|
||||
- provider_type: inline::rag-runtime
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
|
|
|
@ -49,22 +49,22 @@ The deployed platform includes the NIM Proxy microservice, which is the service
|
|||
### Datasetio API: NeMo Data Store
|
||||
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Datasetio docs::llama_stack/providers/remote/datasetio/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Datasetio docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Eval API: NeMo Evaluator
|
||||
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Eval docs::llama_stack/providers/remote/eval/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Eval docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Post-Training API: NeMo Customizer
|
||||
The NeMo Customizer microservice supports fine-tuning models. You can reference {repopath}`this list of supported models::llama_stack/providers/remote/post_training/nvidia/models.py` that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Post-Training docs::llama_stack/providers/remote/post_training/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Post-Training docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
### Safety API: NeMo Guardrails
|
||||
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||
|
||||
See the {repopath}`NVIDIA Safety docs::llama_stack/providers/remote/safety/nvidia/README.md` for supported features and example usage.
|
||||
See the [NVIDIA Safety docs](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/safety/nvidia/README.md) for supported features and example usage.
|
||||
|
||||
## Deploying models
|
||||
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||
|
@ -138,4 +138,4 @@ llama stack run ./run.yaml \
|
|||
```
|
||||
|
||||
## Example Notebooks
|
||||
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in {repopath}`docs/notebooks/nvidia`.
|
||||
For examples of how to use the NVIDIA Distribution to run inference, fine-tune, evaluate, and run safety checks on your LLMs, you can reference the example notebooks in [docs/notebooks/nvidia](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks/nvidia).
|
||||
|
|
|
@ -7,15 +7,15 @@
|
|||
from pathlib import Path
|
||||
|
||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
|
||||
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [BuildProvider(provider_type="remote::nvidia")],
|
||||
"vector_io": [BuildProvider(provider_type="inline::faiss")],
|
||||
|
@ -30,6 +30,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
],
|
||||
"scoring": [BuildProvider(provider_type="inline::basic")],
|
||||
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
|
@ -52,6 +53,11 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_type="remote::nvidia",
|
||||
config=NVIDIAEvalConfig.sample_run_config(),
|
||||
)
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
inference_model = ModelInput(
|
||||
model_id="${env.INFERENCE_MODEL}",
|
||||
provider_id="nvidia",
|
||||
|
@ -61,9 +67,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_id="nvidia",
|
||||
)
|
||||
|
||||
available_models = {
|
||||
"nvidia": MODEL_ENTRIES,
|
||||
}
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::rag",
|
||||
|
@ -71,23 +74,21 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
),
|
||||
]
|
||||
|
||||
default_models, _ = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name="nvidia",
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
|
||||
container_image=None,
|
||||
template_path=Path(__file__).parent / "doc_template.md",
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider],
|
||||
"datasetio": [datasetio_provider],
|
||||
"eval": [eval_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models,
|
||||
default_tool_groups=default_tool_groups,
|
||||
),
|
||||
"run-with-safety.yaml": RunConfigSettings(
|
||||
|
@ -97,6 +98,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
safety_provider,
|
||||
],
|
||||
"eval": [eval_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=[inference_model, safety_model],
|
||||
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
|
||||
|
|
|
@ -4,6 +4,7 @@ apis:
|
|||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
|
@ -88,6 +89,14 @@ providers:
|
|||
tool_runtime:
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
|
||||
|
|
|
@ -4,6 +4,7 @@ apis:
|
|||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
|
@ -77,96 +78,21 @@ providers:
|
|||
tool_runtime:
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db
|
||||
inference_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-405b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-1b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-3b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-11b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-90b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: nvidia/vila
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/vila
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: nvidia/nv-embedqa-e5-v5
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/nv-embedqa-e5-v5
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 4096
|
||||
context_length: 512
|
||||
model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1024
|
||||
context_length: 512
|
||||
model_id: snowflake/arctic-embed-l
|
||||
provider_id: nvidia
|
||||
provider_model_id: snowflake/arctic-embed-l
|
||||
model_type: embedding
|
||||
models: []
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
|
|
|
@ -18,6 +18,7 @@ distribution_spec:
|
|||
- provider_type: remote::vertexai
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: remote::azure
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
|
|
|
@ -81,6 +81,13 @@ providers:
|
|||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
|
|
|
@ -18,6 +18,7 @@ distribution_spec:
|
|||
- provider_type: remote::vertexai
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: remote::azure
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
|
|
|
@ -81,6 +81,13 @@ providers:
|
|||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
|
|
|
@ -59,6 +59,7 @@ ENABLED_INFERENCE_PROVIDERS = [
|
|||
"cerebras",
|
||||
"nvidia",
|
||||
"bedrock",
|
||||
"azure",
|
||||
]
|
||||
|
||||
INFERENCE_PROVIDER_IDS = {
|
||||
|
@ -68,6 +69,7 @@ INFERENCE_PROVIDER_IDS = {
|
|||
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
|
||||
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
|
||||
"vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}",
|
||||
"azure": "${env.AZURE_API_KEY:+azure}",
|
||||
}
|
||||
|
||||
|
||||
|
@ -76,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
|
|||
remote_providers = [
|
||||
provider
|
||||
for provider in available_providers()
|
||||
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||
if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||
]
|
||||
|
||||
inference_providers = []
|
||||
for provider_spec in remote_providers:
|
||||
provider_type = provider_spec.adapter.adapter_type
|
||||
provider_type = provider_spec.adapter_type
|
||||
|
||||
if provider_type in INFERENCE_PROVIDER_IDS:
|
||||
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
|
||||
|
@ -277,5 +279,21 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
"http://localhost:11434",
|
||||
"Ollama URL",
|
||||
),
|
||||
"AZURE_API_KEY": (
|
||||
"",
|
||||
"Azure API Key",
|
||||
),
|
||||
"AZURE_API_BASE": (
|
||||
"",
|
||||
"Azure API Base",
|
||||
),
|
||||
"AZURE_API_VERSION": (
|
||||
"",
|
||||
"Azure API Version",
|
||||
),
|
||||
"AZURE_API_TYPE": (
|
||||
"azure",
|
||||
"Azure API Type",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ apis:
|
|||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
- files
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: watsonx
|
||||
|
@ -94,6 +95,14 @@ providers:
|
|||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/watsonx/files}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/files_metadata.db
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db
|
||||
|
|
|
@ -9,6 +9,7 @@ from pathlib import Path
|
|||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
@ -16,7 +17,7 @@ from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
|||
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
def get_distribution_template(name: str = "watsonx") -> DistributionTemplate:
|
||||
providers = {
|
||||
"inference": [
|
||||
BuildProvider(provider_type="remote::watsonx"),
|
||||
|
@ -42,6 +43,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::rag-runtime"),
|
||||
BuildProvider(provider_type="remote::model-context-protocol"),
|
||||
],
|
||||
"files": [BuildProvider(provider_type="inline::localfs")],
|
||||
}
|
||||
|
||||
inference_provider = Provider(
|
||||
|
@ -79,9 +81,14 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
},
|
||||
)
|
||||
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_models, _ = get_model_registry(available_models)
|
||||
return DistributionTemplate(
|
||||
name="watsonx",
|
||||
name=name,
|
||||
distro_type="remote_hosted",
|
||||
description="Use watsonx for running LLM inference",
|
||||
container_image=None,
|
||||
|
@ -92,6 +99,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
"inference": [inference_provider, embedding_provider],
|
||||
"files": [files_provider],
|
||||
},
|
||||
default_models=default_models + [embedding_model],
|
||||
default_tool_groups=default_tool_groups,
|
||||
|
|
|
@ -92,6 +92,8 @@ class ToolParamDefinition(BaseModel):
|
|||
param_type: str
|
||||
description: str | None = None
|
||||
required: bool | None = True
|
||||
items: Any | None = None
|
||||
title: str | None = None
|
||||
default: Any | None = None
|
||||
|
||||
|
||||
|
|
|
@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
|
|||
""",
|
||||
)
|
||||
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
|
||||
|
||||
# used internally by the resolver; this is a hack for now
|
||||
|
@ -145,45 +154,8 @@ class RoutingTable(Protocol):
|
|||
async def get_provider_impl(self, routing_key: str) -> Any: ...
|
||||
|
||||
|
||||
# TODO: this can now be inlined into RemoteProviderSpec
|
||||
@json_schema_type
|
||||
class AdapterSpec(BaseModel):
|
||||
adapter_type: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
module: str = Field(
|
||||
default_factory=str,
|
||||
description="""
|
||||
Fully-qualified name of the module to import. The module is expected to have:
|
||||
|
||||
- `get_adapter_impl(config, deps)`: returns the adapter implementation
|
||||
""",
|
||||
)
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
config_class: str = Field(
|
||||
description="Fully-qualified classname of the config for this provider",
|
||||
)
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
A description of the provider. This is used to display in the documentation.
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InlineProviderSpec(ProviderSpec):
|
||||
pip_packages: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="The pip dependencies needed for this implementation",
|
||||
)
|
||||
container_image: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
|
|||
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
|
||||
""",
|
||||
)
|
||||
# module field is inherited from ProviderSpec
|
||||
provider_data_validator: str | None = Field(
|
||||
default=None,
|
||||
)
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
|
@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
|
|||
|
||||
@json_schema_type
|
||||
class RemoteProviderSpec(ProviderSpec):
|
||||
adapter: AdapterSpec = Field(
|
||||
adapter_type: str = Field(
|
||||
...,
|
||||
description="Unique identifier for this adapter",
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
If some code is needed to convert the remote responses into Llama Stack compatible
|
||||
API responses, specify the adapter here.
|
||||
A description of the provider. This is used to display in the documentation.
|
||||
""",
|
||||
)
|
||||
|
||||
|
@ -234,33 +207,6 @@ API responses, specify the adapter here.
|
|||
def container_image(self) -> str | None:
|
||||
return None
|
||||
|
||||
# module field is inherited from ProviderSpec
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
return self.adapter.pip_packages
|
||||
|
||||
@property
|
||||
def provider_data_validator(self) -> str | None:
|
||||
return self.adapter.provider_data_validator
|
||||
|
||||
|
||||
def remote_provider_spec(
|
||||
api: Api,
|
||||
adapter: AdapterSpec,
|
||||
api_dependencies: list[Api] | None = None,
|
||||
optional_api_dependencies: list[Api] | None = None,
|
||||
) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
module=adapter.module,
|
||||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
optional_api_dependencies=optional_api_dependencies or [],
|
||||
)
|
||||
|
||||
|
||||
class HealthStatus(StrEnum):
|
||||
OK = "OK"
|
||||
|
|
|
@ -830,6 +830,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
|
@ -873,6 +875,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
items=param.items,
|
||||
title=param.title,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
|
@ -952,7 +956,7 @@ async def get_raw_document_text(document: Document) -> str:
|
|||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"):
|
||||
elif not (document.mime_type.startswith("text/") or document.mime_type in ("application/yaml", "application/json")):
|
||||
raise ValueError(f"Unexpected document mime type: {document.mime_type}")
|
||||
|
||||
if isinstance(document.content, URL):
|
||||
|
|
|
@ -237,6 +237,7 @@ class OpenAIResponsesImpl:
|
|||
response_tools=tools,
|
||||
temperature=temperature,
|
||||
response_format=response_format,
|
||||
inputs=input,
|
||||
)
|
||||
|
||||
# Create orchestrator and delegate streaming logic
|
||||
|
|
|
@ -10,10 +10,12 @@ from typing import Any
|
|||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
AllowedToolsFilter,
|
||||
ApprovalFilter,
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
|
@ -50,6 +52,36 @@ from .utils import convert_chat_choice_to_response_message, is_function_tool_cal
|
|||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
def convert_tooldef_to_chat_tool(tool_def):
    """Translate a tools-API ToolDef into an OpenAI chat-completion tool.

    Args:
        tool_def: ToolDef from the tools API

    Returns:
        ChatCompletionToolParam suitable for OpenAI chat completion
    """
    from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
    from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool

    # One ToolParamDefinition per declared parameter, keyed by parameter name.
    param_definitions = {}
    for param in tool_def.parameters:
        param_definitions[param.name] = ToolParamDefinition(
            param_type=param.parameter_type,
            description=param.description,
            required=param.required,
            default=param.default,
            items=param.items,
        )

    internal_tool_def = ToolDefinition(
        tool_name=tool_def.name,
        description=tool_def.description,
        parameters=param_definitions,
    )
    return convert_tooldef_to_openai_tool(internal_tool_def)
|
||||
|
||||
|
||||
class StreamingResponseOrchestrator:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -117,10 +149,17 @@ class StreamingResponseOrchestrator:
|
|||
raise ValueError("Streaming chunk processor failed to return completion data")
|
||||
current_response = self._build_chat_completion(completion_result_data)
|
||||
|
||||
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
|
||||
function_tool_calls, non_function_tool_calls, approvals, next_turn_messages = self._separate_tool_calls(
|
||||
current_response, messages
|
||||
)
|
||||
|
||||
# add any approval requests required
|
||||
for tool_call in approvals:
|
||||
async for evt in self._add_mcp_approval_request(
|
||||
tool_call.function.name, tool_call.function.arguments, output_messages
|
||||
):
|
||||
yield evt
|
||||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
|
@ -164,10 +203,11 @@ class StreamingResponseOrchestrator:
|
|||
# Emit response.completed
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
|
||||
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
|
||||
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list, list]:
|
||||
"""Separate tool calls into function and non-function categories."""
|
||||
function_tool_calls = []
|
||||
non_function_tool_calls = []
|
||||
approvals = []
|
||||
next_turn_messages = messages.copy()
|
||||
|
||||
for choice in current_response.choices:
|
||||
|
@ -178,9 +218,23 @@ class StreamingResponseOrchestrator:
|
|||
if is_function_tool_call(tool_call, self.ctx.response_tools):
|
||||
function_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
if self._approval_required(tool_call.function.name):
|
||||
approval_response = self.ctx.approval_response(
|
||||
tool_call.function.name, tool_call.function.arguments
|
||||
)
|
||||
if approval_response:
|
||||
if approval_response.approve:
|
||||
logger.info(f"Approval granted for {tool_call.id} on {tool_call.function.name}")
|
||||
non_function_tool_calls.append(tool_call)
|
||||
else:
|
||||
logger.info(f"Approval denied for {tool_call.id} on {tool_call.function.name}")
|
||||
else:
|
||||
logger.info(f"Requesting approval for {tool_call.id} on {tool_call.function.name}")
|
||||
approvals.append(tool_call)
|
||||
else:
|
||||
non_function_tool_calls.append(tool_call)
|
||||
|
||||
return function_tool_calls, non_function_tool_calls, next_turn_messages
|
||||
return function_tool_calls, non_function_tool_calls, approvals, next_turn_messages
|
||||
|
||||
async def _process_streaming_chunks(
|
||||
self, completion_result, output_messages: list[OpenAIResponseOutput]
|
||||
|
@ -556,23 +610,7 @@ class StreamingResponseOrchestrator:
|
|||
continue
|
||||
if not always_allowed or t.name in always_allowed:
|
||||
# Add to chat tools for inference
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
|
||||
tool_def = ToolDefinition(
|
||||
tool_name=t.name,
|
||||
description=t.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in t.parameters
|
||||
},
|
||||
)
|
||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
openai_tool = convert_tooldef_to_chat_tool(t)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
|
@ -632,3 +670,46 @@ class StreamingResponseOrchestrator:
|
|||
# TODO: Emit mcp_list_tools.failed event if needed
|
||||
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
|
||||
raise
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
return False
|
||||
mcp_server = self.mcp_tool_to_server[tool_name]
|
||||
if mcp_server.require_approval == "always":
|
||||
return True
|
||||
if mcp_server.require_approval == "never":
|
||||
return False
|
||||
if isinstance(mcp_server, ApprovalFilter):
|
||||
if tool_name in mcp_server.always:
|
||||
return True
|
||||
if tool_name in mcp_server.never:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _add_mcp_approval_request(
    self, tool_name: str, arguments: str, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
    """Record an MCP approval request as an output item and stream its events.

    Builds an ``OpenAIResponseMCPApprovalRequest`` with a fresh
    ``approval_<uuid4>`` id, appends it to ``output_messages``, then yields an
    output-item-added event followed by an output-item-done event for it.
    ``self.sequence_number`` is incremented once before each event.

    Args:
        tool_name: Tool awaiting approval; must be a key of ``self.mcp_tool_to_server``.
        arguments: Arguments of the pending tool call, as passed by the caller.
        output_messages: Output list that receives the approval request.
    """
    server = self.mcp_tool_to_server[tool_name]
    approval_item = OpenAIResponseMCPApprovalRequest(
        arguments=arguments,
        id=f"approval_{uuid.uuid4()}",
        name=tool_name,
        server_label=server.server_label,
    )
    output_messages.append(approval_item)

    # "added" first, then "done"; the index is re-read for each event.
    for event_cls in (
        OpenAIResponseObjectStreamResponseOutputItemAdded,
        OpenAIResponseObjectStreamResponseOutputItemDone,
    ):
        self.sequence_number += 1
        yield event_cls(
            response_id=self.response_id,
            item=approval_item,
            output_index=len(output_messages) - 1,
            sequence_number=self.sequence_number,
        )
|
||||
|
|
|
@ -10,7 +10,10 @@ from openai.types.chat import ChatCompletionToolParam
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseObjectStream,
|
||||
OpenAIResponseOutput,
|
||||
)
|
||||
|
@ -58,3 +61,37 @@ class ChatCompletionContext(BaseModel):
|
|||
chat_tools: list[ChatCompletionToolParam] | None = None
|
||||
temperature: float | None
|
||||
response_format: OpenAIResponseFormatParam
|
||||
approval_requests: list[OpenAIResponseMCPApprovalRequest] = []
|
||||
approval_responses: dict[str, OpenAIResponseMCPApprovalResponse] = {}
|
||||
|
||||
def __init__(
    self,
    model: str,
    messages: list[OpenAIMessageParam],
    response_tools: list[OpenAIResponseInputTool] | None,
    temperature: float | None,
    response_format: OpenAIResponseFormatParam,
    inputs: list[OpenAIResponseInput] | str,
):
    """Initialise the context and index any MCP approval items in ``inputs``.

    The first five arguments are forwarded unchanged to the base initialiser.
    When ``inputs`` is a list, items whose ``type`` is ``"mcp_approval_request"``
    are collected into ``approval_requests`` and items whose ``type`` is
    ``"mcp_approval_response"`` are stored in ``approval_responses`` keyed by
    their ``approval_request_id``. A plain string input leaves both untouched.
    """
    super().__init__(
        model=model,
        messages=messages,
        response_tools=response_tools,
        temperature=temperature,
        response_format=response_format,
    )
    if isinstance(inputs, str):
        return

    requests = []
    responses = {}
    for item in inputs:
        if item.type == "mcp_approval_request":
            requests.append(item)
        elif item.type == "mcp_approval_response":
            responses[item.approval_request_id] = item
    self.approval_requests = requests
    self.approval_responses = responses
|
||||
|
||||
def approval_response(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalResponse | None:
    """Return the stored approval response for this tool call, if there is one.

    Looks up the approval request matching ``tool_name`` and ``arguments`` and
    then the response recorded under that request's id. Returns None when
    either the request or its response is missing.
    """
    request = self._approval_request(tool_name, arguments)
    if not request:
        return None
    return self.approval_responses.get(request.id)
|
||||
|
||||
def _approval_request(self, tool_name: str, arguments: str) -> OpenAIResponseMCPApprovalRequest | None:
    """Return the first approval request whose name and arguments both match, else None."""
    matching = (
        candidate
        for candidate in self.approval_requests
        if candidate.name == tool_name and candidate.arguments == arguments
    )
    return next(matching, None)
|
||||
|
|
|
@ -13,6 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseMCPApprovalRequest,
|
||||
OpenAIResponseMCPApprovalResponse,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContent,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
|
@ -149,6 +151,11 @@ async def convert_response_input_to_chat_messages(
|
|||
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
|
||||
# the tool list will be handled separately
|
||||
pass
|
||||
elif isinstance(input_item, OpenAIResponseMCPApprovalRequest) or isinstance(
|
||||
input_item, OpenAIResponseMCPApprovalResponse
|
||||
):
|
||||
# these are handled by the responses impl itself and are not passed through to chat completions
|
||||
pass
|
||||
else:
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
|
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.agents import Agents, StepType
|
|||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
|
||||
from llama_stack.apis.inference import Inference, OpenAISystemMessageParam, OpenAIUserMessageParam, UserMessage
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
|
@ -75,6 +75,13 @@ class MetaReferenceEvalImpl(
|
|||
)
|
||||
self.benchmarks[task_def.identifier] = task_def
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
    """Remove a benchmark from the in-memory registry and from the kvstore.

    Unknown ids are tolerated for the in-memory registry; the kvstore delete
    is issued regardless of whether the id was present in ``self.benchmarks``.

    :param benchmark_id: identifier of the benchmark to remove.
    """
    # pop() with a default drops the entry if present and is a no-op otherwise.
    self.benchmarks.pop(benchmark_id, None)

    key = f"{EVAL_TASKS_PREFIX}{benchmark_id}"
    await self.kvstore.delete(key)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
@ -152,31 +159,40 @@ class MetaReferenceEvalImpl(
|
|||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
assert candidate.sampling_params.max_tokens is not None, "SamplingParams.max_tokens must be provided"
|
||||
sampling_params = {"max_tokens": candidate.sampling_params.max_tokens}
|
||||
|
||||
generations = []
|
||||
for x in tqdm(input_rows):
|
||||
if ColumnName.completion_input.value in x:
|
||||
if candidate.sampling_params.stop:
|
||||
sampling_params["stop"] = candidate.sampling_params.stop
|
||||
|
||||
input_content = json.loads(x[ColumnName.completion_input.value])
|
||||
response = await self.inference_api.completion(
|
||||
response = await self.inference_api.openai_completion(
|
||||
model=candidate.model,
|
||||
content=input_content,
|
||||
sampling_params=candidate.sampling_params,
|
||||
prompt=input_content,
|
||||
**sampling_params,
|
||||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].text})
|
||||
elif ColumnName.chat_completion_input.value in x:
|
||||
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in chat_completion_input_json if x["role"] == "user"]
|
||||
input_messages = [
|
||||
OpenAIUserMessageParam(**x) for x in chat_completion_input_json if x["role"] == "user"
|
||||
]
|
||||
|
||||
messages = []
|
||||
if candidate.system_message:
|
||||
messages.append(candidate.system_message)
|
||||
messages += [SystemMessage(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += [OpenAISystemMessageParam(**x) for x in chat_completion_input_json if x["role"] == "system"]
|
||||
|
||||
messages += input_messages
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=candidate.model,
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
model=candidate.model,
|
||||
messages=messages,
|
||||
sampling_params=candidate.sampling_params,
|
||||
**sampling_params,
|
||||
)
|
||||
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
|
||||
generations.append({ColumnName.generated_answer.value: response.choices[0].message.content})
|
||||
else:
|
||||
raise ValueError("Invalid input row")
|
||||
|
||||
|
|
|
@ -9,11 +9,12 @@ import uuid
|
|||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
|
@ -22,6 +23,7 @@ from llama_stack.apis.files import (
|
|||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
@ -44,7 +46,7 @@ class LocalfsFilesImpl(Files):
|
|||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize SQL store for metadata
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
|
||||
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
|
||||
await self.sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -74,7 +76,7 @@ class LocalfsFilesImpl(Files):
|
|||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
|
||||
|
||||
|
@ -86,14 +88,13 @@ class LocalfsFilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""Upload a file that can be used across various endpoints."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
if expires_after is not None:
|
||||
raise NotImplementedError("File expiration is not supported by this provider")
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
|
@ -150,7 +151,6 @@ class LocalfsFilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions if where_conditions else None,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -18,8 +18,6 @@ from llama_stack.apis.common.content_types import (
|
|||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEvent,
|
||||
|
@ -219,41 +217,6 @@ class MetaReferenceInferenceImpl(
|
|||
results = await self._nonstream_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def batch_completion(
    self,
    model_id: str,
    content_batch: list[InterleavedContent],
    sampling_params: SamplingParams | None = None,
    response_format: ResponseFormat | None = None,
    stream: bool | None = False,
    logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse:
    """Run a completion for every content item in ``content_batch`` as one batch.

    Each item is first passed through
    ``augment_content_with_response_format_prompt``, wrapped in a
    ``CompletionRequest``, validated with ``self.check_model`` and converted
    with ``convert_request_to_raw``. The resulting requests are then executed
    together by ``self._nonstream_completion``.

    :param model_id: model identifier placed on every request.
    :param content_batch: the contents to complete, one request per item.
    :param sampling_params: sampling settings; a default ``SamplingParams()``
        is used when omitted.
    :param response_format: optional response format, used both to augment the
        content and as a field on each request.
    :param stream: copied onto each request; note that the batch is always
        executed through ``_nonstream_completion``.
    :param logprobs: optional log-prob configuration; only ``top_k == 1`` is
        accepted.
    :returns: a ``BatchCompletionResponse`` wrapping the per-request results.
    :raises AssertionError: if ``logprobs`` is given with ``top_k != 1``.
    """
    if sampling_params is None:
        sampling_params = SamplingParams()
    if logprobs:
        assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

    content_batch = [
        augment_content_with_response_format_prompt(response_format, content) for content in content_batch
    ]

    request_batch = []
    for content in content_batch:
        request = CompletionRequest(
            model=model_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        self.check_model(request)
        request = await convert_request_to_raw(request)
        request_batch.append(request)

    results = await self._nonstream_completion(request_batch)
    return BatchCompletionResponse(batch=results)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
tokenizer = self.generator.formatter.tokenizer
|
||||
|
||||
|
@ -399,49 +362,6 @@ class MetaReferenceInferenceImpl(
|
|||
results = await self._nonstream_chat_completion([request])
|
||||
return results[0]
|
||||
|
||||
async def batch_chat_completion(
    self,
    model_id: str,
    messages_batch: list[list[Message]],
    sampling_params: SamplingParams | None = None,
    response_format: ResponseFormat | None = None,
    tools: list[ToolDefinition] | None = None,
    stream: bool | None = False,
    logprobs: LogProbConfig | None = None,
    tool_config: ToolConfig | None = None,
) -> BatchChatCompletionResponse:
    """Run a chat completion for every conversation in ``messages_batch`` as one batch.

    One ``ChatCompletionRequest`` is built per conversation, validated with
    ``self.check_model``, has its messages rewritten by
    ``chat_completion_request_to_messages`` and is converted with
    ``convert_request_to_raw``. All requests are then executed together by
    ``self._nonstream_chat_completion``.

    :param model_id: model identifier placed on every request.
    :param messages_batch: one message list per conversation.
    :param sampling_params: sampling settings; a default ``SamplingParams()``
        is used when omitted.
    :param response_format: optional response format set on each request.
    :param tools: tool definitions; treated as an empty list when omitted.
    :param stream: accepted for signature compatibility; it is not forwarded
        to the requests built here.
    :param logprobs: optional log-prob configuration; only ``top_k == 1`` is
        accepted.
    :param tool_config: tool configuration; a default ``ToolConfig()`` is used
        when omitted.
    :returns: a ``BatchChatCompletionResponse`` wrapping the per-request results.
    :raises AssertionError: if ``logprobs`` is given with ``top_k != 1``.
    :raises RuntimeError: if ``create_distributed_process_group`` is enabled
        in the config and ``SEMAPHORE`` is currently locked.
    """
    if sampling_params is None:
        sampling_params = SamplingParams()
    if logprobs:
        assert logprobs.top_k == 1, f"Unexpected top_k={logprobs.top_k}"

    # wrapper request to make it easier to pass around (internal only, not exposed to API)
    request_batch = []
    for messages in messages_batch:
        request = ChatCompletionRequest(
            model=model_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            logprobs=logprobs,
            tool_config=tool_config or ToolConfig(),
        )
        self.check_model(request)

        # augment and rewrite messages depending on the model
        request.messages = chat_completion_request_to_messages(request, self.llama_model.core_model_id.value)
        # download media and convert to raw content so we can send it to the model
        request = await convert_request_to_raw(request)
        request_batch.append(request)

    # This only checks the semaphore; it is not acquired in this method.
    if self.config.create_distributed_process_group:
        if SEMAPHORE.locked():
            raise RuntimeError("Only one concurrent request is supported")

    results = await self._nonstream_chat_completion(request_batch)
    return BatchChatCompletionResponse(batch=results)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request_batch: list[ChatCompletionRequest]
|
||||
) -> list[ChatCompletionResponse]:
|
||||
|
|
|
@ -290,13 +290,13 @@ class LlamaGuardShield:
|
|||
else:
|
||||
shield_input_message = self.build_text_shield_input(messages)
|
||||
|
||||
# TODO: llama-stack inference protocol has issues with non-streaming inference code
|
||||
response = await self.inference_api.chat_completion(
|
||||
model_id=self.model,
|
||||
response = await self.inference_api.openai_chat_completion(
|
||||
model=self.model,
|
||||
messages=[shield_input_message],
|
||||
stream=False,
|
||||
temperature=0.0, # default is 1, which is too high for safety
|
||||
)
|
||||
content = response.completion_message.content
|
||||
content = response.choices[0].message.content
|
||||
content = content.strip()
|
||||
return self.get_shield_response(content)
|
||||
|
||||
|
|
|
@ -63,6 +63,9 @@ class LlmAsJudgeScoringImpl(
|
|||
async def register_scoring_function(self, function_def: ScoringFn) -> None:
|
||||
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
|
||||
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
    """Unregister a scoring function definition.

    Delegates to ``self.llm_as_judge_fn.unregister_scoring_fn_def``.

    :param scoring_fn_id: identifier of the scoring function to remove.
    """
    self.llm_as_judge_fn.unregister_scoring_fn_def(scoring_fn_id)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
|
@ -224,10 +224,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
# Always log to console if console sink is enabled (debug)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}")
|
||||
|
||||
# Add metric as an event to the current span
|
||||
try:
|
||||
with self._lock:
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
from jinja2 import Template
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import UserMessage
|
||||
from llama_stack.apis.inference import OpenAIUserMessageParam
|
||||
from llama_stack.apis.tools.rag_tool import (
|
||||
DefaultRAGQueryGeneratorConfig,
|
||||
LLMRAGQueryGeneratorConfig,
|
||||
|
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
|
|||
messages = [interleaved_content_as_str(content)]
|
||||
|
||||
template = Template(config.template)
|
||||
content = template.render({"messages": messages})
|
||||
rendered_content: str = template.render({"messages": messages})
|
||||
|
||||
model = config.model
|
||||
message = UserMessage(content=content)
|
||||
response = await inference_api.chat_completion(
|
||||
model_id=model,
|
||||
message = OpenAIUserMessageParam(content=rendered_content)
|
||||
response = await inference_api.openai_chat_completion(
|
||||
model=model,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
query = response.completion_message.content
|
||||
query = response.choices[0].message.content
|
||||
|
||||
return query
|
||||
|
|
|
@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
|
|||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
content_from_doc,
|
||||
parse_data_url,
|
||||
)
|
||||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
|
||||
from .config import RagToolRuntimeConfig
|
||||
from .context_retriever import generate_rag_query
|
||||
|
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
|
|||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
||||
|
||||
|
||||
def _decode_data_url(data_url: str) -> tuple[bytes, str]:
    """Decode a ``data:`` URL into ``(payload bytes, mime type)`` via ``parse_data_url``."""
    parts = parse_data_url(data_url)
    mime_type = parts["mimetype"]
    payload = parts["data"]
    # Base64 payloads are decoded to raw bytes; any other payload is encoded as UTF-8 text.
    file_data = base64.b64decode(payload) if parts["is_base64"] else payload.encode("utf-8")
    return file_data, mime_type


async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
    """Get raw binary data and mime type from a RAGDocument for file upload.

    Resolution order:

    * ``URL`` content with a ``data:`` URI is decoded in-process.
    * Any other ``URL`` content is fetched with an HTTP GET; the mime type is
      the response's ``content-type`` header, or ``application/octet-stream``
      when that header is missing.
    * Non-URL content is turned into a string (as-is for ``str``, otherwise via
      ``interleaved_content_as_str``). A string starting with ``data:`` is
      decoded like a data URL; anything else is returned as UTF-8 bytes with
      mime type ``text/plain``.

    :param doc: the document whose ``content`` is to be materialized.
    :returns: ``(file_data, mime_type)``.
    :raises httpx.HTTPStatusError: when the HTTP fetch returns an error status
        (raised by ``raise_for_status``).
    """
    if isinstance(doc.content, URL):
        uri = doc.content.uri
        if uri.startswith("data:"):
            return _decode_data_url(uri)

        async with httpx.AsyncClient() as client:
            r = await client.get(uri)
            r.raise_for_status()
            # NOTE(review): the header value is returned verbatim, so it may carry
            # parameters such as "; charset=utf-8" — confirm callers cope with that.
            mime_type = r.headers.get("content-type", "application/octet-stream")
            return r.content, mime_type

    content_str = doc.content if isinstance(doc.content, str) else interleaved_content_as_str(doc.content)
    if content_str.startswith("data:"):
        return _decode_data_url(content_str)
    return content_str.encode("utf-8"), "text/plain"
|
||||
|
||||
|
||||
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
return
|
||||
|
||||
for doc in documents:
|
||||
if isinstance(doc.content, URL):
|
||||
if doc.content.uri.startswith("data:"):
|
||||
parts = parse_data_url(doc.content.uri)
|
||||
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
||||
mime_type = parts["mimetype"]
|
||||
else:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(doc.content.uri)
|
||||
file_data = response.content
|
||||
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
||||
else:
|
||||
content_str = await content_from_doc(doc)
|
||||
file_data = content_str.encode("utf-8")
|
||||
mime_type = doc.mime_type or "text/plain"
|
||||
try:
|
||||
try:
|
||||
file_data, mime_type = await raw_data_from_doc(doc)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
file_obj = io.BytesIO(file_data)
|
||||
file_obj.name = filename
|
||||
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
try:
|
||||
created_file = await self.files_api.openai_upload_file(
|
||||
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||
static=VectorStoreChunkingStrategyStaticConfig(
|
||||
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
try:
|
||||
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||
vector_store_id=vector_db_id,
|
||||
file_id=created_file.id,
|
||||
attributes=doc.metadata,
|
||||
chunking_strategy=chunking_strategy,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
|
||||
continue
|
||||
|
||||
async def query(
|
||||
self,
|
||||
|
@ -274,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
if query_config:
|
||||
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
|
||||
else:
|
||||
# handle someone passing an empty dict
|
||||
query_config = RAGQueryConfig()
|
||||
|
||||
query = kwargs["query"]
|
||||
|
@ -285,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
)
|
||||
|
||||
return ToolInvocationResult(
|
||||
content=result.content,
|
||||
content=result.content or [],
|
||||
metadata=result.metadata,
|
||||
)
|
||||
|
|
|
@ -13,7 +13,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.batches,
|
||||
provider_type="inline::reference",
|
||||
pip_packages=["openai"],
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.inline.batches.reference",
|
||||
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api_dependencies=[],
|
||||
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="huggingface",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.huggingface",
|
||||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
|
||||
),
|
||||
adapter_type="huggingface",
|
||||
provider_type="remote::huggingface",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.huggingface",
|
||||
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
|
||||
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.datasetio,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.datasetio.nvidia",
|
||||
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
|
||||
pip_packages=[
|
||||
"datasets>=4.0.0",
|
||||
],
|
||||
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
|
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.eval,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"requests",
|
||||
],
|
||||
provider_type="remote::nvidia",
|
||||
module="llama_stack.providers.remote.eval.nvidia",
|
||||
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
|
|
|
@ -4,13 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
|
||||
|
||||
|
||||
|
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
|
||||
description="Local filesystem-based file storage provider for managing files and documents locally.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.files,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
provider_type="remote::s3",
|
||||
adapter_type="s3",
|
||||
pip_packages=["boto3"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.s3",
|
||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
META_REFERENCE_DEPS = [
|
||||
|
@ -49,180 +48,167 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
description="Sentence Transformers inference provider for text embeddings and similarity search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
adapter_type="cerebras",
|
||||
provider_type="remote::cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="ollama",
|
||||
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
description="Ollama inference provider for running local models through the Ollama runtime.",
|
||||
),
|
||||
adapter_type="ollama",
|
||||
provider_type="remote::ollama",
|
||||
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
|
||||
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
|
||||
module="llama_stack.providers.remote.inference.ollama",
|
||||
description="Ollama inference provider for running local models through the Ollama runtime.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vllm",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
),
|
||||
adapter_type="vllm",
|
||||
provider_type="remote::vllm",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.vllm",
|
||||
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
|
||||
description="Remote vLLM inference provider for connecting to vLLM servers.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
|
||||
),
|
||||
adapter_type="tgi",
|
||||
provider_type="remote::tgi",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
|
||||
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
description="HuggingFace Inference API serverless provider for on-demand model inference.",
|
||||
),
|
||||
adapter_type="hf::serverless",
|
||||
provider_type="remote::hf::serverless",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
|
||||
description="HuggingFace Inference API serverless provider for on-demand model inference.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
|
||||
),
|
||||
provider_type="remote::hf::endpoint",
|
||||
adapter_type="hf::endpoint",
|
||||
pip_packages=["huggingface_hub", "aiohttp"],
|
||||
module="llama_stack.providers.remote.inference.tgi",
|
||||
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
|
||||
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||
),
|
||||
adapter_type="fireworks",
|
||||
provider_type="remote::fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai<=0.17.16",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
|
||||
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||
),
|
||||
adapter_type="together",
|
||||
provider_type="remote::together",
|
||||
pip_packages=[
|
||||
"together",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.together",
|
||||
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
|
||||
description="Together AI inference provider for open-source models and collaborative AI development.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||
),
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.inference.bedrock",
|
||||
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
|
||||
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="databricks",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
adapter_type="databricks",
|
||||
provider_type="remote::databricks",
|
||||
pip_packages=["databricks-sdk"],
|
||||
module="llama_stack.providers.remote.inference.databricks",
|
||||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="runpod",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
adapter_type="runpod",
|
||||
provider_type="remote::runpod",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.runpod",
|
||||
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
|
||||
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="openai",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||
),
|
||||
adapter_type="openai",
|
||||
provider_type="remote::openai",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.openai",
|
||||
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
|
||||
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="anthropic",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||
),
|
||||
adapter_type="anthropic",
|
||||
provider_type="remote::anthropic",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.anthropic",
|
||||
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
|
||||
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="gemini",
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
adapter_type="gemini",
|
||||
provider_type="remote::gemini",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.gemini",
|
||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vertexai",
|
||||
pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||
adapter_type="vertexai",
|
||||
provider_type="remote::vertexai",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
"google-cloud-aiplatform",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||
|
||||
• Enterprise-grade security: Uses Google Cloud's security controls and IAM
|
||||
• Better integration: Seamless integration with other Google Cloud services
|
||||
|
@ -242,61 +228,73 @@ Available Models:
|
|||
- vertex_ai/gemini-2.0-flash
|
||||
- vertex_ai/gemini-2.5-flash
|
||||
- vertex_ai/gemini-2.5-pro""",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq",
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
adapter_type="groq",
|
||||
provider_type="remote::groq",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.groq",
|
||||
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
adapter_type="llama-openai-compat",
|
||||
provider_type="remote::llama-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.llama_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
|
||||
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm", "openai"],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||
),
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=[
|
||||
"litellm",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.sambanova",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||
),
|
||||
adapter_type="passthrough",
|
||||
provider_type="remote::passthrough",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.remote.inference.passthrough",
|
||||
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||
),
|
||||
adapter_type="watsonx",
|
||||
provider_type="remote::watsonx",
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="remote::azure",
|
||||
adapter_type="azure",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.azure",
|
||||
config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
|
||||
description="""
|
||||
Azure OpenAI inference provider for accessing GPT models and other Azure services.
|
||||
Provider documentation
|
||||
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
|
||||
""",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from typing import cast
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
|
||||
|
||||
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
|
||||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||
|
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.post_training,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
|
||||
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||
),
|
||||
adapter_type="bedrock",
|
||||
provider_type="remote::bedrock",
|
||||
pip_packages=["boto3"],
|
||||
module="llama_stack.providers.remote.safety.bedrock",
|
||||
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
|
||||
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
adapter_type="nvidia",
|
||||
provider_type="remote::nvidia",
|
||||
pip_packages=["requests"],
|
||||
module="llama_stack.providers.remote.safety.nvidia",
|
||||
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
|
||||
description="NVIDIA's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.safety,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova",
|
||||
pip_packages=["litellm", "requests"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
adapter_type="sambanova",
|
||||
provider_type="remote::sambanova",
|
||||
pip_packages=["litellm", "requests"],
|
||||
module="llama_stack.providers.remote.safety.sambanova",
|
||||
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
|
||||
description="SambaNova's safety provider for content moderation and safety filtering.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -38,7 +38,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.scoring,
|
||||
provider_type="inline::braintrust",
|
||||
pip_packages=["autoevals", "openai"],
|
||||
pip_packages=["autoevals"],
|
||||
module="llama_stack.providers.inline.scoring.braintrust",
|
||||
config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig",
|
||||
api_dependencies=[
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
|
|||
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
||||
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="brave-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||
),
|
||||
adapter_type="brave-search",
|
||||
provider_type="remote::brave-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.brave_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
|
||||
description="Brave Search tool for web search capabilities with privacy-focused results.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="bing-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||
),
|
||||
adapter_type="bing-search",
|
||||
provider_type="remote::bing-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.bing_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
|
||||
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="tavily-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||
),
|
||||
adapter_type="tavily-search",
|
||||
provider_type="remote::tavily-search",
|
||||
module="llama_stack.providers.remote.tool_runtime.tavily_search",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
|
||||
description="Tavily Search tool for AI-optimized web search with structured results.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="wolfram-alpha",
|
||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||
),
|
||||
adapter_type="wolfram-alpha",
|
||||
provider_type="remote::wolfram-alpha",
|
||||
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
|
||||
pip_packages=["requests"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
|
||||
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
RemoteProviderSpec(
|
||||
api=Api.tool_runtime,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp>=1.8.1"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||
),
|
||||
adapter_type="model-context-protocol",
|
||||
provider_type="remote::model-context-protocol",
|
||||
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
|
||||
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
|
||||
pip_packages=["mcp>=1.8.1"],
|
||||
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
|
||||
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
|
||||
),
|
||||
]
|
||||
|
|
|
@ -6,11 +6,10 @@
|
|||
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
|
||||
|
@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
|
|||
Please refer to the sqlite-vec provider documentation.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="chromadb",
|
||||
provider_type="remote::chromadb",
|
||||
pip_packages=["chromadb-client"],
|
||||
module="llama_stack.providers.remote.vector_io.chroma",
|
||||
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Chroma](https://www.trychroma.com/) is an inline and remote vector
|
||||
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -340,9 +341,6 @@ pip install chromadb
|
|||
## Documentation
|
||||
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
|
|||
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
pip_packages=["psycopg2-binary"],
|
||||
module="llama_stack.providers.remote.vector_io.pgvector",
|
||||
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly in memory.
|
||||
That means you'll get fast and efficient vector retrieval.
|
||||
|
@ -410,7 +410,7 @@ There are three implementations of search for PGVectoIndex available:
|
|||
- How it works:
|
||||
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
|
||||
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
|
||||
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
|
||||
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
|
||||
|
||||
-Characteristics:
|
||||
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
|
||||
|
@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17
|
|||
## Documentation
|
||||
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
|
||||
""",
|
||||
),
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
description="""
|
||||
description="""
|
||||
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
|
||||
It allows you to store and query vectors directly within a Weaviate database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
|
|||
## Documentation
|
||||
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
|
@ -594,28 +590,29 @@ docker pull qdrant/qdrant
|
|||
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="qdrant",
|
||||
provider_type="remote::qdrant",
|
||||
pip_packages=["qdrant-client"],
|
||||
module="llama_stack.providers.remote.vector_io.qdrant",
|
||||
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
Please refer to the inline provider documentation.
|
||||
""",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.vector_io,
|
||||
AdapterSpec(
|
||||
adapter_type="milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
description="""
|
||||
RemoteProviderSpec(
|
||||
api=Api.vector_io,
|
||||
adapter_type="milvus",
|
||||
provider_type="remote::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
module="llama_stack.providers.remote.vector_io.milvus",
|
||||
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
description="""
|
||||
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
|
||||
allows you to store and query vectors directly within a Milvus database.
|
||||
That means you're not limited to storing vectors in memory or in a separate service.
|
||||
|
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
|
|||
|
||||
## Installation
|
||||
|
||||
You can install Milvus using pymilvus:
|
||||
If you want to use inline Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus[milvus-lite]
|
||||
```
|
||||
|
||||
If you want to use remote Milvus, you can install:
|
||||
|
||||
```bash
|
||||
pip install pymilvus
|
||||
|
@ -806,14 +809,11 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
|
|||
|
||||
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
|
||||
""",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.vector_io,
|
||||
provider_type="inline::milvus",
|
||||
pip_packages=["pymilvus>=2.4.10"],
|
||||
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
|
||||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
|
|
|
@ -14,7 +14,6 @@ from llama_stack.apis.datasets import Datasets
|
|||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.scoring import Scoring, ScoringResult
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
|
@ -45,24 +44,29 @@ class NVIDIAEvalImpl(
|
|||
self.inference_api = inference_api
|
||||
self.agents_api = agents_api
|
||||
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self)
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def _evaluator_get(self, path):
|
||||
async def _evaluator_get(self, path: str):
|
||||
"""Helper for making GET requests to the evaluator service."""
|
||||
response = requests.get(url=f"{self.config.evaluator_url}{path}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_post(self, path, data):
|
||||
async def _evaluator_post(self, path: str, data: dict[str, Any]):
|
||||
"""Helper for making POST requests to the evaluator service."""
|
||||
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _evaluator_delete(self, path: str) -> None:
    """Send a DELETE request to the evaluator service.

    ``path`` is appended verbatim to ``self.config.evaluator_url``. A non-success
    HTTP status is surfaced through ``raise_for_status()``.
    """
    endpoint = f"{self.config.evaluator_url}{path}"
    requests.delete(url=endpoint).raise_for_status()
|
||||
|
||||
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||
"""Register a benchmark as an evaluation configuration."""
|
||||
await self._evaluator_post(
|
||||
|
@ -75,6 +79,10 @@ class NVIDIAEvalImpl(
|
|||
},
|
||||
)
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
    """Remove the evaluation configuration registered for ``benchmark_id`` from NeMo Evaluator."""
    # The configuration lives under the default namespace, keyed by benchmark id.
    config_path = f"/v1/evaluation/configs/{DEFAULT_NAMESPACE}/{benchmark_id}"
    await self._evaluator_delete(config_path)
|
||||
|
||||
async def run_eval(
|
||||
self,
|
||||
benchmark_id: str,
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Annotated, Any
|
|||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
|
@ -23,6 +23,7 @@ from llama_stack.apis.files import (
|
|||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
@ -137,7 +138,7 @@ class S3FilesImpl(Files):
|
|||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
|
@ -164,7 +165,7 @@ class S3FilesImpl(Files):
|
|||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -195,8 +196,7 @@ class S3FilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
|
||||
|
@ -204,14 +204,6 @@ class S3FilesImpl(Files):
|
|||
|
||||
created_at = self._now()
|
||||
|
||||
expires_after = None
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
# we use ExpiresAfter to validate input
|
||||
expires_after = ExpiresAfter(
|
||||
anchor=expires_after_anchor, # type: ignore[arg-type]
|
||||
seconds=expires_after_seconds, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# the default is no expiration.
|
||||
# to implement no expiration we set an expiration beyond the max.
|
||||
# we'll hide this fact from users when returning the file object.
|
||||
|
@ -268,7 +260,6 @@ class S3FilesImpl(Files):
|
|||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
policy=self.policy,
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
|
|
|
@ -4,15 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import AnthropicConfig
|
||||
|
||||
|
||||
class AnthropicProviderDataValidator(BaseModel):
|
||||
anthropic_api_key: str | None = None
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AnthropicConfig, _deps):
|
||||
from .anthropic import AnthropicInferenceAdapter
|
||||
|
||||
|
|
|
@ -8,14 +8,24 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AnthropicConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class AnthropicInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
# source: https://docs.claude.com/en/docs/build-with-claude/embeddings
|
||||
# TODO: add support for voyageai, which is where these models are hosted
|
||||
# embedding_model_metadata = {
|
||||
# "voyage-3-large": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-3.5-lite": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-code-3": {"embedding_dimension": 1024, "context_length": 32000}, # supports dimensions 256, 512, 1024, 2048
|
||||
# "voyage-finance-2": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# "voyage-law-2": {"embedding_dimension": 1024, "context_length": 16000},
|
||||
# "voyage-multimodal-3": {"embedding_dimension": 1024, "context_length": 32000},
|
||||
# }
|
||||
|
||||
def __init__(self, config: AnthropicConfig) -> None:
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
MODEL_ENTRIES = (
|
||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
]
|
||||
+ SAFETY_MODELS_ENTRIES
|
||||
)
|
15
llama_stack/providers/remote/inference/azure/__init__.py
Normal file
15
llama_stack/providers/remote/inference/azure/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: AzureConfig, _deps):
    """Build and initialize the Azure inference adapter for ``config``.

    ``_deps`` is accepted but not used here.
    """
    # Imported inside the function so the adapter module is only loaded when an adapter is requested.
    from .azure import AzureInferenceAdapter

    adapter = AzureInferenceAdapter(config)
    await adapter.initialize()
    return adapter
|
62
llama_stack/providers/remote/inference/azure/azure.py
Normal file
62
llama_stack/providers/remote/inference/azure/azure.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import AzureConfig
|
||||
|
||||
|
||||
class AzureInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
    """Azure inference adapter built on ``OpenAIMixin`` and ``LiteLLMOpenAIMixin``.

    Connection settings come from ``AzureConfig``. When per-request provider data
    is present, ``_get_params`` takes the ``azure_api_*`` values from it instead.
    """

    def __init__(self, config: AzureConfig) -> None:
        """Initialize the LiteLLM mixin from ``config`` and keep a reference to it."""
        LiteLLMOpenAIMixin.__init__(
            self,
            litellm_provider_name="azure",
            api_key_from_config=config.api_key.get_secret_value(),
            provider_data_api_key_field="azure_api_key",
            openai_compat_api_base=str(config.api_base),
        )
        self.config = config

    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
    get_api_key = LiteLLMOpenAIMixin.get_api_key

    def get_base_url(self) -> str:
        """
        Get the Azure API base URL.

        Returns the configured API base joined with the absolute path
        ``/openai/v1`` (an absolute path replaces any path already on the base).
        """
        return urljoin(str(self.config.api_base), "/openai/v1")

    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Build the request parameters, adding the Azure-specific settings.

        Starts from the parent's parameters. If request provider data is
        available, each ``azure_api_*`` value that is set overrides the matching
        parameter; otherwise all four values are taken from the adapter config.
        """
        # Get base parameters from parent
        params = await super()._get_params(request)

        provider_data = self.get_request_provider_data()
        if not provider_data:
            # No per-request overrides: use the static configuration.
            params["api_key"] = self.config.api_key.get_secret_value()
            params["api_base"] = str(self.config.api_base)
            params["api_version"] = self.config.api_version
            params["api_type"] = self.config.api_type
            return params

        # Provider data may carry pydantic wrapper types (the Azure provider-data
        # validator declares the key as SecretStr and the base as HttpUrl), so
        # normalise both to plain strings exactly like the config branch does.
        api_key = getattr(provider_data, "azure_api_key", None)
        if api_key:
            unwrap = getattr(api_key, "get_secret_value", None)
            params["api_key"] = unwrap() if callable(unwrap) else api_key

        api_base = getattr(provider_data, "azure_api_base", None)
        if api_base:
            params["api_base"] = str(api_base)

        api_version = getattr(provider_data, "azure_api_version", None)
        if api_version:
            params["api_version"] = api_version

        api_type = getattr(provider_data, "azure_api_type", None)
        if api_type:
            params["api_type"] = api_type

        return params
|
63
llama_stack/providers/remote/inference/azure/config.py
Normal file
63
llama_stack/providers/remote/inference/azure/config.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class AzureProviderDataValidator(BaseModel):
    # Shape of the per-request provider data accepted for Azure.
    # The key and base have no default, so they are required; version and type are optional.
    # (Documented with comments rather than a docstring to leave the model's schema untouched.)
    azure_api_key: SecretStr = Field(description="Azure API key for Azure")
    azure_api_base: HttpUrl = Field(
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)"
    )
    azure_api_version: str | None = Field(
        description="Azure API version for Azure (e.g., 2024-06-01)",
        default=None,
    )
    azure_api_type: str | None = Field(
        description="Azure API type for Azure (e.g., azure)",
        default="azure",
    )
|
||||
|
||||
|
||||
@json_schema_type
class AzureConfig(BaseModel):
    # Static configuration for the Azure inference provider.
    # (Documented with comments rather than a docstring to leave the model's schema untouched.)
    api_key: SecretStr = Field(description="Azure API key for Azure")
    api_base: HttpUrl = Field(
        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)"
    )
    # Defaults to the AZURE_API_VERSION environment variable, read when the model is created.
    api_version: str | None = Field(
        description="Azure API version for Azure (e.g., 2024-12-01-preview)",
        default_factory=lambda: os.getenv("AZURE_API_VERSION"),
    )
    # Defaults to the AZURE_API_TYPE environment variable, or "azure" when it is unset.
    api_type: str | None = Field(
        description="Azure API type for Azure (e.g., azure)",
        default_factory=lambda: os.getenv("AZURE_API_TYPE", "azure"),
    )

    @classmethod
    def sample_run_config(
        cls,
        api_key: str = "${env.AZURE_API_KEY:=}",
        api_base: str = "${env.AZURE_API_BASE:=}",
        api_version: str = "${env.AZURE_API_VERSION:=}",
        api_type: str = "${env.AZURE_API_TYPE:=}",
        **kwargs,
    ) -> dict[str, Any]:
        """Return a sample run-config mapping; each value defaults to an ``${env...}`` placeholder.

        Extra keyword arguments are accepted and ignored.
        """
        return dict(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            api_type=api_type,
        )
|
|
@ -11,21 +11,17 @@ from botocore.client import BaseClient
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -47,12 +43,47 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
REGION_PREFIX_MAP = {
|
||||
"us": "us.",
|
||||
"eu": "eu.",
|
||||
"ap": "ap.",
|
||||
}
|
||||
|
||||
|
||||
def _get_region_prefix(region: str | None) -> str:
|
||||
# AWS requires region prefixes for inference profiles
|
||||
if region is None:
|
||||
return "us." # default to US when we don't know
|
||||
|
||||
# Handle case insensitive region matching
|
||||
region_lower = region.lower()
|
||||
for prefix in REGION_PREFIX_MAP:
|
||||
if region_lower.startswith(f"{prefix}-"):
|
||||
return REGION_PREFIX_MAP[prefix]
|
||||
|
||||
# Fallback to US for anything we don't recognize
|
||||
return "us."
|
||||
|
||||
|
||||
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
|
||||
# Return ARNs unchanged
|
||||
if model_id.startswith("arn:"):
|
||||
return model_id
|
||||
|
||||
# Return inference profile IDs that already have regional prefixes
|
||||
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
|
||||
return model_id
|
||||
|
||||
# Default to US East when no region is provided
|
||||
if region is None:
|
||||
region = "us-east-1"
|
||||
|
||||
return _get_region_prefix(region) + model_id
|
||||
|
||||
|
||||
class BedrockInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
|
@ -61,7 +92,7 @@ class BedrockInferenceAdapter(
|
|||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self._config = config
|
||||
self._client = None
|
||||
|
||||
|
@ -166,8 +197,13 @@ class BedrockInferenceAdapter(
|
|||
options["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
|
||||
|
||||
# Convert foundation model ID to inference profile ID
|
||||
region_name = self.client.meta.region_name
|
||||
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
|
||||
|
||||
return {
|
||||
"modelId": bedrock_model,
|
||||
"modelId": inference_profile_id,
|
||||
"body": json.dumps(
|
||||
{
|
||||
"prompt": prompt,
|
||||
|
@ -176,31 +212,6 @@ class BedrockInferenceAdapter(
|
|||
),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embeddings = []
|
||||
for content in contents:
|
||||
assert not content_has_media(content), "Bedrock does not support media for embeddings"
|
||||
input_text = interleaved_content_as_str(content)
|
||||
input_body = {"inputText": input_text}
|
||||
body = json.dumps(input_body)
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=model.provider_resource_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -5,26 +5,23 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -35,42 +32,41 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
)
|
||||
|
||||
from .config import CerebrasImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
ModelRegistryHelper,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
def __init__(self, config: CerebrasImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self.client = AsyncCerebras(
|
||||
self._cerebras_client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
)
|
||||
|
||||
def get_api_key(self) -> str:
    """Return the plain-text value of the API key stored in the adapter config."""
    secret = self.config.api_key
    return secret.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
    """Return the configured base URL joined with the relative segment ``v1``."""
    configured_base = self.config.base_url
    return urljoin(configured_base, "v1")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
|
@ -107,14 +103,14 @@ class CerebrasInferenceAdapter(
|
|||
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_completion_response(r)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_completion_stream_response(stream):
|
||||
yield chunk
|
||||
|
@ -156,14 +152,14 @@ class CerebrasInferenceAdapter(
|
|||
async def _nonstream_chat_completion(self, request: CompletionRequest) -> CompletionResponse:
|
||||
params = await self._get_params(request)
|
||||
|
||||
r = await self.client.completions.create(**params)
|
||||
r = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = await self._get_params(request)
|
||||
|
||||
stream = await self.client.completions.create(**params)
|
||||
stream = await self._cerebras_client.completions.create(**params)
|
||||
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
@ -187,16 +183,6 @@ class CerebrasInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -20,8 +20,8 @@ class CerebrasImplConfig(BaseModel):
|
|||
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
|
||||
description="Base URL for the Cerebras API",
|
||||
)
|
||||
api_key: SecretStr | None = Field(
|
||||
default=os.environ.get("CEREBRAS_API_KEY"),
|
||||
api_key: SecretStr = Field(
|
||||
default=SecretStr(os.environ.get("CEREBRAS_API_KEY")),
|
||||
description="Cerebras API Key",
|
||||
)
|
||||
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://inference-docs.cerebras.ai/models
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"llama3.1-8b",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-3.3-70b",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
|
@ -5,10 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
|
||||
async def get_adapter_impl(config: DatabricksImplConfig, _deps):
|
||||
from .databricks import DatabricksInferenceAdapter
|
||||
|
||||
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
|
||||
impl = DatabricksInferenceAdapter(config)
|
||||
await impl.initialize()
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
@ -17,16 +17,16 @@ class DatabricksImplConfig(BaseModel):
|
|||
default=None,
|
||||
description="The URL for the Databricks model serving endpoint",
|
||||
)
|
||||
api_token: str = Field(
|
||||
default=None,
|
||||
api_token: SecretStr = Field(
|
||||
default=SecretStr(None),
|
||||
description="The Databricks API token",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL:=}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||
url: str = "${env.DATABRICKS_HOST:=}",
|
||||
api_token: str = "${env.DATABRICKS_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -4,74 +4,59 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import OpenAI
|
||||
from databricks.sdk import WorkspaceClient
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
Model,
|
||||
OpenAICompletion,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
get_sampling_options,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"databricks-meta-llama-3-1-70b-instruct",
|
||||
CoreModelId.llama3_1_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"databricks-meta-llama-3-1-405b-instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
logger = get_logger(name=__name__, category="inference::databricks")
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(
|
||||
ModelRegistryHelper,
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
OpenAIChatCompletionToLlamaStackMixin,
|
||||
OpenAICompletionToLlamaStackMixin,
|
||||
):
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
self.config = config
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
return self.config.api_token.get_secret_value()
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
return f"{self.config.url}/serving-endpoints"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
|
@ -80,89 +65,80 @@ class DatabricksInferenceAdapter(
|
|||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
content: InterleavedContent,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str | list[str] | list[int] | list[list[int]],
|
||||
best_of: int | None = None,
|
||||
echo: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str, float] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
max_tokens: int | None = None,
|
||||
n: int | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
seed: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
stream: bool | None = None,
|
||||
stream_options: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
guided_choice: list[str] | None = None,
|
||||
prompt_logprobs: int | None = None,
|
||||
suffix: str | None = None,
|
||||
) -> OpenAICompletion:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
model_id: str,
|
||||
messages: list[Message],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_choice: ToolChoice | None = ToolChoice.auto,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
tools=tools or [],
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
|
||||
client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
|
||||
if stream:
|
||||
return self._stream_chat_completion(request, client)
|
||||
else:
|
||||
return await self._nonstream_chat_completion(request, client)
|
||||
|
||||
async def _nonstream_chat_completion(
|
||||
self, request: ChatCompletionRequest, client: OpenAI
|
||||
) -> ChatCompletionResponse:
|
||||
params = self._get_params(request)
|
||||
r = client.completions.create(**params)
|
||||
return process_chat_completion_response(r, request)
|
||||
|
||||
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
|
||||
params = self._get_params(request)
|
||||
|
||||
async def _to_async_generator():
|
||||
s = client.completions.create(**params)
|
||||
for chunk in s:
|
||||
yield chunk
|
||||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
return {
|
||||
"model": request.model,
|
||||
"prompt": chat_completion_request_to_prompt(request, self.get_llama_model(request.model)),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
input: str | list[str],
|
||||
encoding_format: str | None = "float",
|
||||
dimensions: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._model_cache = {} # from OpenAIMixin
|
||||
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
|
||||
endpoints = ws_client.serving_endpoints.list()
|
||||
for endpoint in endpoints:
|
||||
model = Model(
|
||||
provider_id=self.__provider_id__,
|
||||
provider_resource_id=endpoint.name,
|
||||
identifier=endpoint.name,
|
||||
)
|
||||
if endpoint.task == "llm/v1/chat":
|
||||
model.model_type = ModelType.llm # this is redundant, but informative
|
||||
elif endpoint.task == "llm/v1/embeddings":
|
||||
if endpoint.name not in self.embedding_model_metadata:
|
||||
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
|
||||
continue
|
||||
model.model_type = ModelType.embedding
|
||||
model.metadata = self.embedding_model_metadata[endpoint.name]
|
||||
else:
|
||||
logger.warning(f"Unknown model type, skipping: {endpoint}")
|
||||
continue
|
||||
|
||||
self._model_cache[endpoint.name] = model
|
||||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue