Merge branch 'main' into feat/add-dana-agent-provider-stub

Zooey Nguyen 2025-11-14 09:22:10 -08:00 committed by GitHub
commit 3b3a2d0ceb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
418 changed files with 24245 additions and 1794 deletions

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .agents import *

View file

@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batches import Batches, BatchObject, ListBatchesResponse
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .benchmarks import *

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .conversations import (
Conversation,
ConversationDeletedResource,
ConversationItem,
ConversationItemCreateRequest,
ConversationItemDeletedResource,
ConversationItemList,
Conversations,
Metadata,
)
__all__ = [
"Conversation",
"ConversationDeletedResource",
"ConversationItem",
"ConversationItemCreateRequest",
"ConversationItemDeletedResource",
"ConversationItemList",
"Conversations",
"Metadata",
]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasetio import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .datasets import *

View file

@ -1,158 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum, EnumMeta
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class DynamicApiMeta(EnumMeta):
def __new__(cls, name, bases, namespace):
# Store the original enum values
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
# Create the enum class
cls = super().__new__(cls, name, bases, namespace)
# Store the original values for reference
cls._original_values = original_values
# Initialize _dynamic_values
cls._dynamic_values = {}
return cls
def __call__(cls, value):
try:
return super().__call__(value)
except ValueError as e:
# If this value was already dynamically added, return it
if value in cls._dynamic_values:
return cls._dynamic_values[value]
# If the value doesn't exist, create a new enum member
# Create a new member name from the value
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return the existing member
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
# register new APIs explicitly
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
def __iter__(cls):
# Allow iteration over both static and dynamic members
yield from super().__iter__()
if hasattr(cls, "_dynamic_values"):
yield from cls._dynamic_values.values()
def add(cls, value):
"""
Add a new API to the enum.
Used to register external APIs.
"""
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return it
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Create a new enum member
member = object.__new__(cls)
member._name_ = member_name
member._value_ = value
# Add it to the enum class
cls._member_map_[member_name] = member
cls._member_names_.append(member_name)
cls._member_type_ = str
# Store it in our dynamic values
cls._dynamic_values[value] = member
return member
@json_schema_type
class Api(Enum, metaclass=DynamicApiMeta):
"""Enumeration of all available APIs in the Llama Stack system.
:cvar providers: Provider management and configuration
:cvar inference: Text generation, chat completions, and embeddings
:cvar safety: Content moderation and safety shields
:cvar agents: Agent orchestration and execution
:cvar batches: Batch processing for asynchronous API requests
:cvar vector_io: Vector database operations and queries
:cvar datasetio: Dataset input/output operations
:cvar scoring: Model output evaluation and scoring
:cvar eval: Model evaluation and benchmarking framework
:cvar post_training: Fine-tuning and model training
:cvar tool_runtime: Tool execution and management
:cvar telemetry: Observability and system monitoring
:cvar models: Model metadata and management
:cvar shields: Safety shield implementations
:cvar datasets: Dataset creation and management
:cvar scoring_functions: Scoring function definitions
:cvar benchmarks: Benchmark suite management
:cvar tool_groups: Tool group organization
:cvar files: File storage and management
:cvar prompts: Prompt versions and management
:cvar inspect: Built-in system inspection and introspection
"""
providers = "providers"
inference = "inference"
safety = "safety"
agents = "agents"
batches = "batches"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"
eval = "eval"
post_training = "post_training"
tool_runtime = "tool_runtime"
models = "models"
shields = "shields"
vector_stores = "vector_stores" # only used for routing table
datasets = "datasets"
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
files = "files"
prompts = "prompts"
conversations = "conversations"
# built-in API
inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: str | None = None
class ExternalApiSpec(BaseModel):
"""Specification for an external API implementation."""
module: str = Field(..., description="Python module containing the API implementation")
name: str = Field(..., description="Name of the API")
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
protocol: str = Field(..., description="Name of the protocol class for the API")
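
The deleted module above is the dynamic Api enum that this commit relocates into the new llama_stack_api package. A minimal sketch of the registration flow it implements, assuming the names re-export unchanged from llama_stack_api as the other hunks in this diff suggest (the "weather-forecast" API, my_pkg module, and package names are hypothetical):

from llama_stack_api import Api, ExternalApiSpec

# Looking up an unregistered value raises, steering callers to Api.add():
try:
    Api("weather-forecast")
except ValueError as e:
    print(e)  # API 'weather-forecast' does not exist. Use Api.add() ...

# Explicit registration creates and memoizes a new enum member.
weather = Api.add("weather-forecast")
assert weather.value == "weather-forecast"
assert Api.add("weather-forecast") is weather   # idempotent
assert Api("weather-forecast") is weather       # lookup now resolves
assert weather in list(Api)                     # dynamic members iterate too

# An external API implementation is then described with an ExternalApiSpec:
spec = ExternalApiSpec(
    module="my_pkg.weather",          # hypothetical implementation module
    name="weather-forecast",
    pip_packages=["my-weather-sdk"],  # hypothetical dependency
    protocol="WeatherForecast",       # hypothetical protocol class name
)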

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .files import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inference import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .inspect import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .models import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .post_training import *

View file

@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .prompts import ListPromptsResponse, Prompt, Prompts
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .providers import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .safety import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .scoring_functions import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .shields import *

View file

@ -1,8 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .rag_tool import *
from .tools import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_io import *

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .vector_stores import *

View file

@ -21,7 +21,7 @@ from llama_stack.core.datatypes import (
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import replace_env_vars
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"

View file

@ -32,7 +32,7 @@ from llama_stack.core.storage.datatypes import (
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"

View file

@ -13,7 +13,7 @@ from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api
log = get_logger(name=__name__, category="core")

View file

@ -15,7 +15,7 @@ import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.providers.datatypes import RemoteProviderConfig
from llama_stack_api import RemoteProviderConfig
_CLIENT_CLASSES = {}

View file

@ -20,7 +20,7 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack_api import Api, ProviderSpec
logger = get_logger(name=__name__, category="core")

View file

@ -10,7 +10,12 @@ from typing import Any, Literal
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.conversations.conversations import (
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from llama_stack_api import (
Conversation,
ConversationDeletedResource,
ConversationItem,
@ -20,11 +25,6 @@ from llama_stack.apis.conversations.conversations import (
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
logger = get_logger(name=__name__, category="openai_conversations")

View file

@ -11,20 +11,6 @@ from urllib.parse import urlparse
from pydantic import BaseModel, Field, field_validator, model_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.storage.datatypes import (
KVStoreReference,
@ -32,7 +18,32 @@ from llama_stack.core.storage.datatypes import (
StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack_api import (
Api,
Benchmark,
BenchmarkInput,
Dataset,
DatasetInput,
DatasetIO,
Eval,
Inference,
Model,
ModelInput,
ProviderSpec,
Resource,
Safety,
Scoring,
ScoringFn,
ScoringFnInput,
Shield,
ShieldInput,
ToolGroup,
ToolGroupInput,
ToolRuntime,
VectorIO,
VectorStore,
VectorStoreInput,
)
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2

View file

@ -15,7 +15,7 @@ from pydantic import BaseModel
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
from llama_stack_api import (
Api,
InlineProviderSpec,
ProviderSpec,

View file

@ -7,9 +7,9 @@
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
from llama_stack_api import Api, ExternalApiSpec
logger = get_logger(name=__name__, category="core")

View file

@ -8,17 +8,17 @@ from importlib.metadata import version
from pydantic import BaseModel
from llama_stack.apis.inspect import (
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack_api import (
HealthInfo,
HealthStatus,
Inspect,
ListRoutesResponse,
RouteInfo,
VersionInfo,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
class DistributionInspectConfig(BaseModel):

View file

@ -19,6 +19,8 @@ import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_api import is_unwrapped_body_param
try:
from llama_stack_client import (
NOT_GIVEN,
@ -57,7 +59,6 @@ from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger, setup_logging
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
logger = get_logger(name=__name__, category="core")

View file

@ -9,9 +9,9 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
class PromptServiceConfig(BaseModel):

View file

@ -9,9 +9,8 @@ from typing import Any
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from llama_stack_api import HealthResponse, HealthStatus, ListProvidersResponse, ProviderInfo, Providers
from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields

View file

@ -8,29 +8,6 @@ import importlib.metadata
import inspect
from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
AccessRule,
@ -44,17 +21,44 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.store import DistributionRegistry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
from llama_stack_api import (
LLAMA_STACK_API_V1ALPHA,
Agents,
Api,
Batches,
Benchmarks,
BenchmarksProtocolPrivate,
Conversations,
DatasetIO,
Datasets,
DatasetsProtocolPrivate,
Eval,
ExternalApiSpec,
Files,
Inference,
InferenceProvider,
Inspect,
Models,
ModelsProtocolPrivate,
PostTraining,
Prompts,
ProviderSpec,
RemoteProviderConfig,
RemoteProviderSpec,
Safety,
Scoring,
ScoringFunctions,
ScoringFunctionsProtocolPrivate,
Shields,
ShieldsProtocolPrivate,
ToolGroups,
ToolGroupsProtocolPrivate,
ToolRuntime,
VectorIO,
VectorStore,
)
from llama_stack_api import (
Providers as ProvidersAPI,
)
logger = get_logger(name=__name__, category="core")

View file

@ -12,8 +12,8 @@ from llama_stack.core.datatypes import (
)
from llama_stack.core.stack import StackRunConfig
from llama_stack.core.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import Api, RoutingTable
async def get_routing_table_impl(

View file

@ -6,11 +6,8 @@
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack_api import DatasetIO, DatasetPurpose, DataSource, PaginatedResponse, RoutingTable
logger = get_logger(name=__name__, category="core::routers")

View file

@ -6,15 +6,18 @@
from typing import Any
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack_api import (
BenchmarkConfig,
Eval,
EvaluateResponse,
Job,
RoutingTable,
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core::routers")

View file

@ -15,13 +15,25 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
from llama_stack.core.telemetry.telemetry import MetricEvent
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
HealthResponse,
HealthStatus,
Inference,
ListOpenAIChatCompletionResponse,
ModelNotFoundError,
ModelType,
ModelTypeError,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
@ -35,19 +47,8 @@ from llama_stack.apis.inference import (
OpenAIMessageParam,
Order,
RerankResponse,
RoutingTable,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.telemetry.telemetry import MetricEvent
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
logger = get_logger(name=__name__, category="core::routers")
@ -416,7 +417,7 @@ class InferenceRouter(Inference):
prompt_tokens=chunk.usage.prompt_tokens,
completion_tokens=chunk.usage.completion_tokens,
total_tokens=chunk.usage.total_tokens,
model_id=fully_qualified_model_id,
fully_qualified_model_id=fully_qualified_model_id,
provider_id=provider_id,
)
for metric in metrics:

View file

@ -6,13 +6,9 @@
from typing import Any
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
logger = get_logger(name=__name__, category="core::routers")

View file

@ -6,14 +6,12 @@
from typing import Any
from llama_stack.apis.common.content_types import (
from llama_stack.log import get_logger
from llama_stack_api import (
URL,
)
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolRuntime,
)
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
@ -36,16 +34,16 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
provider = await self.routing_table.get_provider_impl(tool_name)
return await provider.invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
authorization=authorization,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)

View file

@ -10,13 +10,20 @@ from typing import Annotated, Any
from fastapi import Body
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack_api import (
Chunk,
HealthResponse,
HealthStatus,
InterleavedContent,
ModelNotFoundError,
ModelType,
ModelTypeError,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
RoutingTable,
SearchRankingOptions,
VectorIO,
VectorStoreChunkingStrategy,
@ -33,9 +40,6 @@ from llama_stack.apis.vector_io import (
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core::routers")
@ -122,6 +126,14 @@ class VectorIORouter(VectorIO):
if embedding_model is not None and embedding_dimension is None:
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
# Validate that embedding model exists and is of the correct type
if embedding_model is not None:
model = await self.routing_table.get_object_by_identifier("model", embedding_model)
if model is None:
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
# Auto-select provider if not specified
if provider_id is None:
num_providers = len(self.routing_table.impls_by_provider_id)

View file

@ -6,11 +6,11 @@
from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.core.datatypes import (
BenchmarkWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import Benchmark, Benchmarks, ListBenchmarksResponse
from .common import CommonRoutingTableImpl

View file

@ -6,9 +6,6 @@
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (
@ -21,7 +18,7 @@ from llama_stack.core.datatypes import (
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack_api import Api, Model, ModelNotFoundError, ResourceType, RoutingTable
logger = get_logger(name=__name__, category="core::routing_tables")

View file

@ -7,22 +7,22 @@
import uuid
from typing import Any
from llama_stack.apis.common.errors import DatasetNotFoundError
from llama_stack.apis.datasets import (
from llama_stack.core.datatypes import (
DatasetWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import (
Dataset,
DatasetNotFoundError,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
ResourceType,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
DatasetWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl

View file

@ -7,8 +7,6 @@
import time
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
@ -16,6 +14,15 @@ from llama_stack.core.datatypes import (
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack_api import (
ListModelsResponse,
Model,
ModelNotFoundError,
Models,
ModelType,
OpenAIListModelsResponse,
OpenAIModel,
)
from .common import CommonRoutingTableImpl, lookup_model

View file

@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.core.datatypes import (
ScoringFnWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import (
ListScoringFunctionsResponse,
ParamType,
ResourceType,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from .common import CommonRoutingTableImpl

View file

@ -6,12 +6,11 @@
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.core.datatypes import (
ShieldWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import ListShieldsResponse, ResourceType, Shield, Shields
from .common import CommonRoutingTableImpl

View file

@ -6,11 +6,17 @@
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger
from llama_stack_api import (
URL,
ListToolDefsResponse,
ListToolGroupsResponse,
ToolDef,
ToolGroup,
ToolGroupNotFoundError,
ToolGroups,
)
from .common import CommonRoutingTableImpl
@ -43,7 +49,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
routing_key = self.tool_to_toolgroup[routing_key]
return await super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
async def list_tools(
self, toolgroup_id: str | None = None, authorization: str | None = None
) -> ListToolDefsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
@ -55,7 +63,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
try:
await self._index_tools(toolgroup)
await self._index_tools(toolgroup, authorization=authorization)
except AuthenticationRequiredError:
# Send authentication errors back to the client so it knows
# that it needs to supply credentials for remote MCP servers.
@ -70,9 +78,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
return ListToolDefsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
tooldefs_response = await provider_impl.list_runtime_tools(
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
)
tooldefs = tooldefs_response.data
for t in tooldefs:

View file

@ -6,12 +6,17 @@
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
VectorStoreWithOwner,
)
from llama_stack.log import get_logger
# Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
from llama_stack_api import (
ModelNotFoundError,
ModelType,
ModelTypeError,
ResourceType,
SearchRankingOptions,
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
@ -22,10 +27,6 @@ from llama_stack.apis.vector_io.vector_io import (
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
VectorStoreWithOwner,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model

View file

@ -13,7 +13,6 @@ import httpx
import jwt
from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import TokenValidationError
from llama_stack.core.datatypes import (
AuthenticationConfig,
CustomAuthConfig,
@ -23,6 +22,7 @@ from llama_stack.core.datatypes import (
User,
)
from llama_stack.log import get_logger
from llama_stack_api import TokenValidationError
logger = get_logger(name=__name__, category="core::auth")

View file

@ -12,9 +12,8 @@ from typing import Any
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
from llama_stack_api import Api, ExternalApiSpec, WebMethod
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]

View file

@ -31,8 +31,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
AuthenticationRequiredError,
@ -58,7 +56,7 @@ from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import LoggingConfig, get_logger, setup_logging
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api, ConflictError, PaginatedResponse, ResourceNotFoundError
from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware
@ -526,8 +524,8 @@ def extract_path_params(route: str) -> list[str]:
def remove_disabled_providers(obj):
if isinstance(obj, dict):
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
# Filter out items where provider_id is explicitly disabled or empty
if "provider_id" in obj and obj["provider_id"] in ("__disabled__", "", None):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):
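
The rewritten filter above narrows the disable check: previously a mapping was dropped if any of provider_id, shield_id, provider_model_id, or model_id was "__disabled__", empty, or None; now only provider_id is consulted. An illustrative before/after on hypothetical run-config data:

run_config = {
    "providers": [
        {"provider_id": "__disabled__", "provider_type": "remote::foo"},  # dropped
        {"provider_id": "inline::faiss"},                                 # kept
    ],
    "models": [
        # Dropped by the old check because model_id is empty; kept by the
        # new one, since only provider_id is examined.
        {"model_id": "", "provider_id": "inline::meta-reference"},
    ],
}
cleaned = remove_disabled_providers(run_config)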

View file

@ -13,26 +13,6 @@ from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry
@ -54,7 +34,30 @@ from llama_stack.core.storage.datatypes import (
from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import (
Agents,
Api,
Batches,
Benchmarks,
Conversations,
DatasetIO,
Datasets,
Eval,
Files,
Inference,
Inspect,
Models,
PostTraining,
Prompts,
Providers,
Safety,
Scoring,
ScoringFunctions,
Shields,
ToolGroups,
ToolRuntime,
VectorIO,
)
logger = get_logger(name=__name__, category="core")

View file

@ -28,7 +28,7 @@ from pydantic import BaseModel, Field
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema
from llama_stack_api import json_schema_type, register_schema
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
BuildProvider,
ModelInput,
@ -17,6 +16,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
from llama_stack_api import ModelType
def get_distribution_template() -> DistributionTemplate:

View file

@ -6,7 +6,6 @@
from pathlib import Path
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
BuildProvider,
ModelInput,
@ -22,6 +21,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
)
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack_api import ModelType
def get_distribution_template() -> DistributionTemplate:

View file

@ -5,8 +5,6 @@
# the root directory of this source tree.
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
BenchmarkInput,
BuildProvider,
@ -34,6 +32,7 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
)
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack_api import DatasetPurpose, ModelType, URIDataSource
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:

View file

@ -19,7 +19,6 @@ from llama_stack.core.datatypes import (
)
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig,
@ -38,6 +37,7 @@ from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOC
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
from llama_stack_api import RemoteProviderSpec
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:

View file

@ -12,8 +12,6 @@ import rich
import yaml
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.models import ModelType
from llama_stack.core.datatypes import (
LLAMA_STACK_RUN_CONFIG_VERSION,
Api,
@ -44,6 +42,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
from llama_stack_api import DatasetPurpose, ModelType
def filter_empty_values(obj: Any) -> Any:

View file

@ -5,29 +5,29 @@
# the root directory of this source tree.
from llama_stack.apis.agents import (
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from llama_stack_api import (
Agents,
Conversations,
Inference,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
OpenAIResponsePrompt,
OpenAIResponseText,
Order,
ResponseGuardrail,
Safety,
ToolGroups,
ToolRuntime,
VectorIO,
)
from llama_stack.apis.agents.agents import ResponseGuardrail
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .config import MetaReferenceAgentsImplConfig
from .responses.openai_responses import OpenAIResponsesImpl

View file

@ -10,12 +10,20 @@ from collections.abc import AsyncIterator
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack_api import (
ConversationItem,
Conversations,
Inference,
InvalidConversationIdError,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIMessageParam,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
@ -25,24 +33,13 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponsePrompt,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.common.errors import (
InvalidConversationIdError,
)
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.conversations.conversations import ConversationItem
from llama_stack.apis.inference import (
Inference,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
Order,
ResponseGuardrailSpec,
Safety,
ToolGroups,
ToolRuntime,
VectorIO,
)
from .streaming import StreamingResponseOrchestrator
@ -260,6 +257,19 @@ class OpenAIResponsesImpl:
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Validate MCP tools: ensure Authorization header is not passed via headers dict
if tools:
from llama_stack_api.openai_responses import OpenAIResponseInputToolMCP
for tool in tools:
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.headers:
for key in tool.headers.keys():
if key.lower() == "authorization":
raise ValueError(
"Authorization header cannot be passed via 'headers'. "
"Please use the 'authorization' parameter instead."
)
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
if conversation is not None:
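
The validation added above rejects bearer credentials smuggled through an MCP tool's headers mapping (the key check is case-insensitive) and directs callers to the dedicated authorization field introduced elsewhere in this commit. A hedged sketch of the two request shapes, with field names assumed from the MCP tool definition and illustrative values:

mcp_tool_rejected = {
    "type": "mcp",
    "server_label": "example",
    "server_url": "https://mcp.example.com",
    "headers": {"Authorization": "Bearer sk-..."},  # raises ValueError above
}
mcp_tool_accepted = {
    "type": "mcp",
    "server_label": "example",
    "server_url": "https://mcp.example.com",
    "authorization": "Bearer sk-...",  # the supported path for credentials
}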

View file

@ -8,10 +8,21 @@ import uuid
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import (
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack_api import (
AllowedToolsFilter,
ApprovalFilter,
Inference,
MCPListToolsTool,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartReasoningText,
OpenAIResponseContentPartRefusal,
@ -56,19 +67,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseUsageOutputTokensDetails,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChatCompletionToolCall,
OpenAIChoice,
OpenAIMessageParam,
)
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import (
@ -1025,9 +1023,9 @@ class StreamingResponseOrchestrator:
"""Process all tools and emit appropriate streaming events."""
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.tools import ToolDef
from llama_stack.models.llama.datatypes import ToolDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
from llama_stack_api import ToolDef
def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
@ -1093,10 +1091,12 @@ class StreamingResponseOrchestrator:
"server_url": mcp_tool.server_url,
"mcp_list_tools_id": list_id,
}
# List MCP tools with authorization from tool config
async with tracing.span("list_mcp_tools", attributes):
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
headers=mcp_tool.headers,
authorization=mcp_tool.authorization,
)
# Create the MCP list tools message

View file

@ -9,7 +9,14 @@ import json
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import (
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from llama_stack_api import (
ImageContentItem,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIImageURL,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
@ -23,24 +30,14 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIImageURL,
OpenAIToolMessageParam,
TextContentItem,
ToolGroups,
ToolInvocationResult,
ToolRuntime,
VectorIO,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
@ -299,12 +296,14 @@ class ToolExecutor:
"server_url": mcp_tool.server_url,
"tool_name": function_name,
}
# Invoke MCP tool with authorization from tool config
async with tracing.span("invoke_mcp_tool", attributes):
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
headers=mcp_tool.headers,
authorization=mcp_tool.authorization,
)
elif function_name == "knowledge_search":
response_file_search_tool = (
@ -398,6 +397,10 @@ class ToolExecutor:
# Build output message
message: Any
if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack_api import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=item_id,
arguments=function.arguments,

View file

@ -10,7 +10,10 @@ from typing import cast
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
from llama_stack_api import (
OpenAIChatCompletionToolCall,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseInputToolFileSearch,
@ -26,7 +29,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseTool,
OpenAIResponseToolMCP,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
class ToolExecutionResult(BaseModel):

View file

@ -9,9 +9,23 @@ import re
import uuid
from collections.abc import Sequence
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
from llama_stack_api import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseAnnotationFileCitation,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
@ -27,28 +41,12 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
ResponseGuardrailSpec,
Safety,
)
from llama_stack.apis.safety import Safety
async def convert_chat_choice_to_response_message(

View file

@ -6,10 +6,9 @@
import asyncio
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
log = get_logger(name=__name__, category="agents::meta_reference")

View file

@ -6,11 +6,9 @@
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.core.datatypes import AccessRule, Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack_api import Files, Inference, Models
from .batches import ReferenceBatchesImpl
from .config import ReferenceBatchesImplConfig

View file

@ -16,24 +16,28 @@ from typing import Any, Literal
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack_api import (
Batches,
BatchObject,
ConflictError,
Files,
Inference,
ListBatchesResponse,
Models,
OpenAIAssistantMessageParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIFilePurpose,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
ResourceNotFoundError,
)
from llama_stack.apis.models import Models
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from .config import ReferenceBatchesImplConfig

View file

@ -5,13 +5,10 @@
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse
from .config import LocalFSDatasetIOConfig

View file

@@ -8,24 +8,27 @@ from typing import Any
 from tqdm import tqdm
-from llama_stack.apis.agents import Agents
-from llama_stack.apis.benchmarks import Benchmark
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.inference import (
+from llama_stack.providers.utils.common.data_schema_validator import ColumnName
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack_api import (
+    Agents,
+    Benchmark,
+    BenchmarkConfig,
+    BenchmarksProtocolPrivate,
+    DatasetIO,
+    Datasets,
+    Eval,
+    EvaluateResponse,
     Inference,
+    Job,
+    JobStatus,
     OpenAIChatCompletionRequestWithExtraBody,
     OpenAICompletionRequestWithExtraBody,
     OpenAISystemMessageParam,
     OpenAIUserMessageParam,
+    Scoring,
 )
-from llama_stack.apis.scoring import Scoring
-from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
-from llama_stack.providers.utils.common.data_schema_validator import ColumnName
-from llama_stack.providers.utils.kvstore import kvstore_impl
-from .....apis.common.job_types import Job, JobStatus
-from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
 from .config import MetaReferenceEvalConfig
 EVAL_TASKS_PREFIX = "benchmarks:"

View file

@@ -11,16 +11,6 @@ from typing import Annotated
 from fastapi import Depends, File, Form, Response, UploadFile
-from llama_stack.apis.common.errors import ResourceNotFoundError
-from llama_stack.apis.common.responses import Order
-from llama_stack.apis.files import (
-    ExpiresAfter,
-    Files,
-    ListOpenAIFileResponse,
-    OpenAIFileDeleteResponse,
-    OpenAIFileObject,
-    OpenAIFilePurpose,
-)
 from llama_stack.core.datatypes import AccessRule
 from llama_stack.core.id_generation import generate_object_id
 from llama_stack.log import get_logger
@@ -28,6 +18,16 @@ from llama_stack.providers.utils.files.form_data import parse_expires_after
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
 from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
+from llama_stack_api import (
+    ExpiresAfter,
+    Files,
+    ListOpenAIFileResponse,
+    OpenAIFileDeleteResponse,
+    OpenAIFileObject,
+    OpenAIFilePurpose,
+    Order,
+    ResourceNotFoundError,
+)
 from .config import LocalfsFilesImplConfig

View file

@@ -8,8 +8,8 @@ from typing import Any
 from pydantic import BaseModel, field_validator
-from llama_stack.apis.inference import QuantizationConfig
 from llama_stack.providers.utils.inference import supported_inference_models
+from llama_stack_api import QuantizationConfig
 class MetaReferenceInferenceConfig(BaseModel):

View file

@@ -10,7 +10,13 @@ from typing import Optional
 import torch
 from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
-from llama_stack.apis.inference import (
+from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
+from llama_stack.models.llama.llama3.generation import Llama3
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
+from llama_stack.models.llama.llama4.generation import Llama4
+from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
+from llama_stack.models.llama.sku_types import Model, ModelFamily
+from llama_stack_api import (
     GreedySamplingStrategy,
     JsonSchemaResponseFormat,
     OpenAIChatCompletionRequestWithExtraBody,
@@ -20,12 +26,6 @@ from llama_stack.apis.inference import (
     SamplingParams,
     TopPSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
-from llama_stack.models.llama.llama3.generation import Llama3
-from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
-from llama_stack.models.llama.llama4.generation import Llama4
-from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
-from llama_stack.models.llama.sku_types import Model, ModelFamily
 from .common import model_checkpoint_dir
 from .config import MetaReferenceInferenceConfig

View file

@@ -9,22 +9,6 @@ import time
 import uuid
 from collections.abc import AsyncIterator
-from llama_stack.apis.inference import (
-    InferenceProvider,
-    OpenAIAssistantMessageParam,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAIChatCompletionUsage,
-    OpenAIChoice,
-    OpenAICompletionRequestWithExtraBody,
-    OpenAIUserMessageParam,
-    ToolChoice,
-)
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
@@ -40,7 +24,6 @@ from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
@@ -48,6 +31,22 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
 )
+from llama_stack_api import (
+    InferenceProvider,
+    Model,
+    ModelsProtocolPrivate,
+    ModelType,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAIChoice,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIUserMessageParam,
+    ToolChoice,
+)
 from .config import MetaReferenceInferenceConfig
 from .generators import LlamaGenerator
@@ -376,7 +375,7 @@ class MetaReferenceInferenceImpl(
         # Convert tool calls to OpenAI format
         openai_tool_calls = None
         if decoded_message.tool_calls:
-            from llama_stack.apis.inference import (
+            from llama_stack_api import (
                 OpenAIChatCompletionToolCall,
                 OpenAIChatCompletionToolCallFunction,
             )
@@ -441,15 +440,15 @@ class MetaReferenceInferenceImpl(
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> AsyncIterator[OpenAIChatCompletionChunk]:
         """Stream chat completion chunks as they're generated."""
-        from llama_stack.apis.inference import (
+        from llama_stack.models.llama.datatypes import StopReason
+        from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
+        from llama_stack_api import (
             OpenAIChatCompletionChunk,
             OpenAIChatCompletionToolCall,
             OpenAIChatCompletionToolCallFunction,
             OpenAIChoiceDelta,
             OpenAIChunkChoice,
         )
-        from llama_stack.models.llama.datatypes import StopReason
-        from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
         response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
         created = int(time.time())

View file

@@ -6,22 +6,21 @@
 from collections.abc import AsyncIterator
-from llama_stack.apis.inference import (
-    InferenceProvider,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAICompletionRequestWithExtraBody,
-)
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.apis.models import ModelType
 from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.embedding_mixin import (
     SentenceTransformerEmbeddingMixin,
 )
+from llama_stack_api import (
+    InferenceProvider,
+    Model,
+    ModelsProtocolPrivate,
+    ModelType,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+)
 from .config import SentenceTransformersInferenceConfig

View file

@@ -12,14 +12,10 @@
 from typing import Any
-from llama_stack.apis.common.type_system import (
-    ChatCompletionInputType,
-    DialogType,
-    StringType,
-)
 from llama_stack.providers.utils.common.data_schema_validator import (
     ColumnName,
 )
+from llama_stack_api import ChatCompletionInputType, DialogType, StringType
 EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
     "instruct": [

View file

@@ -6,11 +6,16 @@
 from enum import Enum
 from typing import Any
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.providers.inline.post_training.huggingface.config import (
+    HuggingFacePostTrainingConfig,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack_api import (
     AlgorithmConfig,
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -19,11 +24,6 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
-from llama_stack.providers.inline.post_training.huggingface.config import (
-    HuggingFacePostTrainingConfig,
-)
-from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
-from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 class TrainingArtifactType(Enum):

View file

@@ -18,16 +18,16 @@ from transformers import (
 )
 from trl import SFTConfig, SFTTrainer
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.log import get_logger
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
+from llama_stack_api import (
     Checkpoint,
     DataConfig,
+    DatasetIO,
+    Datasets,
     LoraFinetuningConfig,
     TrainingConfig,
 )
-from llama_stack.log import get_logger
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (
View file

@@ -16,15 +16,15 @@ from transformers import (
 )
 from trl import DPOConfig, DPOTrainer
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.log import get_logger
+from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
+from llama_stack_api import (
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     TrainingConfig,
 )
-from llama_stack.log import get_logger
-from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
 from ..config import HuggingFacePostTrainingConfig
 from ..utils import (

View file

@@ -16,6 +16,8 @@ import torch
 from datasets import Dataset
 from transformers import AutoConfig, AutoModelForCausalLM
+from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -34,8 +36,6 @@ class HFAutoModel(Protocol):
     def save_pretrained(self, save_directory: str | Path) -> None: ...
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.post_training import Checkpoint, TrainingConfig
 from llama_stack.log import get_logger
 from .config import HuggingFacePostTrainingConfig

View file

@@ -21,9 +21,9 @@ from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
-from llama_stack.apis.post_training import DatasetFormat
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import Model
+from llama_stack_api import DatasetFormat
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]

View file

@@ -6,11 +6,16 @@
 from enum import Enum
 from typing import Any
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
+from llama_stack.providers.inline.post_training.torchtune.config import (
+    TorchtunePostTrainingConfig,
+)
+from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
+from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
+from llama_stack_api import (
     AlgorithmConfig,
     Checkpoint,
+    DatasetIO,
+    Datasets,
     DPOAlignmentConfig,
     JobStatus,
     ListPostTrainingJobsResponse,
@@ -20,11 +25,6 @@ from llama_stack.apis.post_training import (
     PostTrainingJobStatusResponse,
     TrainingConfig,
 )
-from llama_stack.providers.inline.post_training.torchtune.config import (
-    TorchtunePostTrainingConfig,
-)
-from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
-from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
 class TrainingArtifactType(Enum):

View file

@@ -32,17 +32,6 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
 from torchtune.training.metric_logging import DiskLogger
 from tqdm import tqdm
-from llama_stack.apis.common.training_types import PostTrainingMetric
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.post_training import (
-    Checkpoint,
-    DataConfig,
-    LoraFinetuningConfig,
-    OptimizerConfig,
-    QATFinetuningConfig,
-    TrainingConfig,
-)
 from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
 from llama_stack.core.utils.model_utils import model_local_dir
 from llama_stack.log import get_logger
@@ -56,6 +45,17 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
+from llama_stack_api import (
+    Checkpoint,
+    DataConfig,
+    DatasetIO,
+    Datasets,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    PostTrainingMetric,
+    QATFinetuningConfig,
+    TrainingConfig,
+)
 log = get_logger(name=__name__, category="post_training")

View file

@@ -10,19 +10,20 @@ from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from codeshield.cs import CodeShieldScanResult
-from llama_stack.apis.inference import OpenAIMessageParam
-from llama_stack.apis.safety import (
-    RunShieldResponse,
-    Safety,
-    SafetyViolation,
-    ViolationLevel,
-)
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
-from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
+from llama_stack_api import (
+    ModerationObject,
+    ModerationObjectResults,
+    OpenAIMessageParam,
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    Shield,
+    ViolationLevel,
+)
 from .config import CodeScannerConfig

View file

@@ -9,29 +9,29 @@ import uuid
 from string import Template
 from typing import Any
-from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
-from llama_stack.apis.inference import (
-    Inference,
-    OpenAIChatCompletionRequestWithExtraBody,
-    OpenAIMessageParam,
-    OpenAIUserMessageParam,
-)
-from llama_stack.apis.safety import (
-    RunShieldResponse,
-    Safety,
-    SafetyViolation,
-    ViolationLevel,
-)
-from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
-from llama_stack.apis.shields import Shield
 from llama_stack.core.datatypes import Api
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import Role
 from llama_stack.models.llama.sku_types import CoreModelId
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
+from llama_stack_api import (
+    ImageContentItem,
+    Inference,
+    ModerationObject,
+    ModerationObjectResults,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIMessageParam,
+    OpenAIUserMessageParam,
+    RunShieldResponse,
+    Safety,
+    SafetyViolation,
+    Shield,
+    ShieldsProtocolPrivate,
+    TextContentItem,
+    ViolationLevel,
+)
 from .config import LlamaGuardConfig

View file

@@ -9,20 +9,20 @@ from typing import Any
 import torch
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
-from llama_stack.apis.inference import OpenAIMessageParam
-from llama_stack.apis.safety import (
+from llama_stack.core.utils.model_utils import model_local_dir
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
+from llama_stack_api import (
+    ModerationObject,
+    OpenAIMessageParam,
     RunShieldResponse,
     Safety,
     SafetyViolation,
+    Shield,
+    ShieldsProtocolPrivate,
+    ShieldStore,
     ViolationLevel,
 )
-from llama_stack.apis.safety.safety import ModerationObject
-from llama_stack.apis.shields import Shield
-from llama_stack.core.utils.model_utils import model_local_dir
-from llama_stack.log import get_logger
-from llama_stack.providers.datatypes import ShieldsProtocolPrivate
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from .config import PromptGuardConfig, PromptGuardType

View file

@@ -5,21 +5,22 @@
 # the root directory of this source tree.
 from typing import Any
-from llama_stack.apis.datasetio import DatasetIO
-from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.scoring import (
-    ScoreBatchResponse,
-    ScoreResponse,
-    Scoring,
-    ScoringResult,
-)
-from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
 from llama_stack.core.datatypes import Api
-from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from llama_stack.providers.utils.common.data_schema_validator import (
     get_valid_schemas,
     validate_dataset_schema,
 )
+from llama_stack_api import (
+    DatasetIO,
+    Datasets,
+    ScoreBatchResponse,
+    ScoreResponse,
+    Scoring,
+    ScoringFn,
+    ScoringFnParams,
+    ScoringFunctionsProtocolPrivate,
+    ScoringResult,
+)
 from .config import BasicScoringConfig
 from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn

View file

@@ -8,9 +8,8 @@ import json
 import re
 from typing import Any
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringResultRow
 from .fn_defs.docvqa import docvqa

View file

@@ -6,9 +6,8 @@
 from typing import Any
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringResultRow
 from .fn_defs.equality import equality

View file

@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
     BasicScoringFnParams,
+    NumberType,
     ScoringFn,
 )

View file

@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
     BasicScoringFnParams,
+    NumberType,
     ScoringFn,
 )

View file

@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
     BasicScoringFnParams,
+    NumberType,
     ScoringFn,
 )

View file

@@ -4,9 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
+    NumberType,
     RegexParserScoringFnParams,
     ScoringFn,
 )

View file

@@ -4,9 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
+    NumberType,
     RegexParserScoringFnParams,
     ScoringFn,
 )

View file

@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from llama_stack.apis.common.type_system import NumberType
-from llama_stack.apis.scoring_functions import (
+from llama_stack_api import (
     AggregationFunctionType,
     BasicScoringFnParams,
+    NumberType,
     ScoringFn,
 )

View file

@@ -6,9 +6,8 @@
 from typing import Any
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringResultRow
 from .fn_defs.ifeval import (
     ifeval,

View file

@@ -5,9 +5,8 @@
 # the root directory of this source tree.
 from typing import Any
-from llama_stack.apis.scoring import ScoringResultRow
-from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
 from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
+from llama_stack_api import ScoringFnParams, ScoringFnParamsType, ScoringResultRow
 from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
 from .fn_defs.regex_parser_math_response import (

Some files were not shown because too many files have changed in this diff.
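
Taken together, the hunks above apply one mechanical migration: names previously imported from deep llama_stack.apis.* submodules (plus protocol types such as ModelsProtocolPrivate from llama_stack.providers.datatypes) now come from the flat llama_stack_api package, merged into a single alphabetized import block per file. A minimal before/after sketch of the pattern, assuming, as the hunks imply, that llama_stack_api re-exports all of these names at its top level:

# Before: one deep import per API submodule (the removed lines above).
# from llama_stack.apis.inference import OpenAIMessageParam
# from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel

# After: a single flat import; only the module path changes, not the names.
from llama_stack_api import (
    OpenAIMessageParam,
    Safety,
    SafetyViolation,
    ViolationLevel,
)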