mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
Merge branch 'main' into feat/add-dana-agent-provider-stub
This commit is contained in:
commit
3b3a2d0ceb
418 changed files with 24245 additions and 1794 deletions
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .agents import *
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .benchmarks import *
|
||||
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -1,27 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .conversations import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
ConversationItemCreateRequest,
|
||||
ConversationItemDeletedResource,
|
||||
ConversationItemList,
|
||||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Conversation",
|
||||
"ConversationDeletedResource",
|
||||
"ConversationItem",
|
||||
"ConversationItemCreateRequest",
|
||||
"ConversationItemDeletedResource",
|
||||
"ConversationItemList",
|
||||
"Conversations",
|
||||
"Metadata",
|
||||
]
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .datasetio import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .datasets import *
|
||||
|
|
@ -1,158 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class DynamicApiMeta(EnumMeta):
|
||||
def __new__(cls, name, bases, namespace):
|
||||
# Store the original enum values
|
||||
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
|
||||
|
||||
# Create the enum class
|
||||
cls = super().__new__(cls, name, bases, namespace)
|
||||
|
||||
# Store the original values for reference
|
||||
cls._original_values = original_values
|
||||
# Initialize _dynamic_values
|
||||
cls._dynamic_values = {}
|
||||
|
||||
return cls
|
||||
|
||||
def __call__(cls, value):
|
||||
try:
|
||||
return super().__call__(value)
|
||||
except ValueError as e:
|
||||
# If this value was already dynamically added, return it
|
||||
if value in cls._dynamic_values:
|
||||
return cls._dynamic_values[value]
|
||||
|
||||
# If the value doesn't exist, create a new enum member
|
||||
# Create a new member name from the value
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return the existing member
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
|
||||
# register new APIs explicitly
|
||||
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
|
||||
|
||||
def __iter__(cls):
|
||||
# Allow iteration over both static and dynamic members
|
||||
yield from super().__iter__()
|
||||
if hasattr(cls, "_dynamic_values"):
|
||||
yield from cls._dynamic_values.values()
|
||||
|
||||
def add(cls, value):
|
||||
"""
|
||||
Add a new API to the enum.
|
||||
Used to register external APIs.
|
||||
"""
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return it
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Create a new enum member
|
||||
member = object.__new__(cls)
|
||||
member._name_ = member_name
|
||||
member._value_ = value
|
||||
|
||||
# Add it to the enum class
|
||||
cls._member_map_[member_name] = member
|
||||
cls._member_names_.append(member_name)
|
||||
cls._member_type_ = str
|
||||
|
||||
# Store it in our dynamic values
|
||||
cls._dynamic_values[value] = member
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum, metaclass=DynamicApiMeta):
|
||||
"""Enumeration of all available APIs in the Llama Stack system.
|
||||
:cvar providers: Provider management and configuration
|
||||
:cvar inference: Text generation, chat completions, and embeddings
|
||||
:cvar safety: Content moderation and safety shields
|
||||
:cvar agents: Agent orchestration and execution
|
||||
:cvar batches: Batch processing for asynchronous API requests
|
||||
:cvar vector_io: Vector database operations and queries
|
||||
:cvar datasetio: Dataset input/output operations
|
||||
:cvar scoring: Model output evaluation and scoring
|
||||
:cvar eval: Model evaluation and benchmarking framework
|
||||
:cvar post_training: Fine-tuning and model training
|
||||
:cvar tool_runtime: Tool execution and management
|
||||
:cvar telemetry: Observability and system monitoring
|
||||
:cvar models: Model metadata and management
|
||||
:cvar shields: Safety shield implementations
|
||||
:cvar datasets: Dataset creation and management
|
||||
:cvar scoring_functions: Scoring function definitions
|
||||
:cvar benchmarks: Benchmark suite management
|
||||
:cvar tool_groups: Tool group organization
|
||||
:cvar files: File storage and management
|
||||
:cvar prompts: Prompt versions and management
|
||||
:cvar inspect: Built-in system inspection and introspection
|
||||
"""
|
||||
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
batches = "batches"
|
||||
vector_io = "vector_io"
|
||||
datasetio = "datasetio"
|
||||
scoring = "scoring"
|
||||
eval = "eval"
|
||||
post_training = "post_training"
|
||||
tool_runtime = "tool_runtime"
|
||||
|
||||
models = "models"
|
||||
shields = "shields"
|
||||
vector_stores = "vector_stores" # only used for routing table
|
||||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
benchmarks = "benchmarks"
|
||||
tool_groups = "tool_groups"
|
||||
files = "files"
|
||||
prompts = "prompts"
|
||||
conversations = "conversations"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Error(BaseModel):
|
||||
"""
|
||||
Error response from the API. Roughly follows RFC 7807.
|
||||
|
||||
:param status: HTTP status code
|
||||
:param title: Error title, a short summary of the error which is invariant for an error type
|
||||
:param detail: Error detail, a longer human-readable description of the error
|
||||
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
|
||||
"""
|
||||
|
||||
status: int
|
||||
title: str
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
class ExternalApiSpec(BaseModel):
|
||||
"""Specification for an external API implementation."""
|
||||
|
||||
module: str = Field(..., description="Python module containing the API implementation")
|
||||
name: str = Field(..., description="Name of the API")
|
||||
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
|
||||
protocol: str = Field(..., description="Name of the protocol class for the API")
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .eval import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .files import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .inference import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .inspect import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .models import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .post_training import *
|
||||
|
|
@ -1,9 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .prompts import ListPromptsResponse, Prompt, Prompts
|
||||
|
||||
__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .providers import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .safety import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .scoring import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .scoring_functions import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .shields import *
|
||||
|
|
@ -1,8 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .rag_tool import *
|
||||
from .tools import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .vector_io import *
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .vector_stores import *
|
||||
|
|
@ -21,7 +21,7 @@ from llama_stack.core.datatypes import (
|
|||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack_api import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from llama_stack.core.storage.datatypes import (
|
|||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack_api import Api
|
||||
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from llama_stack.core.datatypes import BuildConfig
|
|||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.distributions.template import DistributionTemplate
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack_api import Api
|
||||
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import httpx
|
|||
from pydantic import BaseModel, parse_obj_as
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
from llama_stack_api import RemoteProviderConfig
|
||||
|
||||
_CLIENT_CLASSES = {}
|
||||
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
|||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack_api import Api, ProviderSpec
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,12 @@ from typing import Any, Literal
|
|||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack_api import (
|
||||
Conversation,
|
||||
ConversationDeletedResource,
|
||||
ConversationItem,
|
||||
|
|
@ -20,11 +25,6 @@ from llama_stack.apis.conversations.conversations import (
|
|||
Conversations,
|
||||
Metadata,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
logger = get_logger(name=__name__, category="openai_conversations")
|
||||
|
||||
|
|
|
|||
|
|
@ -11,20 +11,6 @@ from urllib.parse import urlparse
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset, DatasetInput
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Model, ModelInput
|
||||
from llama_stack.apis.resource import Resource
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
|
||||
from llama_stack.apis.shields import Shield, ShieldInput
|
||||
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
KVStoreReference,
|
||||
|
|
@ -32,7 +18,32 @@ from llama_stack.core.storage.datatypes import (
|
|||
StorageConfig,
|
||||
)
|
||||
from llama_stack.log import LoggingConfig
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack_api import (
|
||||
Api,
|
||||
Benchmark,
|
||||
BenchmarkInput,
|
||||
Dataset,
|
||||
DatasetInput,
|
||||
DatasetIO,
|
||||
Eval,
|
||||
Inference,
|
||||
Model,
|
||||
ModelInput,
|
||||
ProviderSpec,
|
||||
Resource,
|
||||
Safety,
|
||||
Scoring,
|
||||
ScoringFn,
|
||||
ScoringFnInput,
|
||||
Shield,
|
||||
ShieldInput,
|
||||
ToolGroup,
|
||||
ToolGroupInput,
|
||||
ToolRuntime,
|
||||
VectorIO,
|
||||
VectorStore,
|
||||
VectorStoreInput,
|
||||
)
|
||||
|
||||
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from pydantic import BaseModel
|
|||
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
from llama_stack_api import (
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@
|
|||
|
||||
import yaml
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Api, ExternalApiSpec
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
|
|||
|
|
@ -8,17 +8,17 @@ from importlib.metadata import version
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inspect import (
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack_api import (
|
||||
HealthInfo,
|
||||
HealthStatus,
|
||||
Inspect,
|
||||
ListRoutesResponse,
|
||||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
||||
class DistributionInspectConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ import httpx
|
|||
import yaml
|
||||
from fastapi import Response as FastAPIResponse
|
||||
|
||||
from llama_stack_api import is_unwrapped_body_param
|
||||
|
||||
try:
|
||||
from llama_stack_client import (
|
||||
NOT_GIVEN,
|
||||
|
|
@ -57,7 +59,6 @@ from llama_stack.core.utils.config import redact_sensitive_fields
|
|||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.core.utils.exec import in_notebook
|
||||
from llama_stack.log import get_logger, setup_logging
|
||||
from llama_stack.strong_typing.inspection import is_unwrapped_body_param
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack_api import ListPromptsResponse, Prompt, Prompts
|
||||
|
||||
|
||||
class PromptServiceConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -9,9 +9,8 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
|
||||
from llama_stack_api import HealthResponse, HealthStatus, ListProvidersResponse, ProviderInfo, Providers
|
||||
|
||||
from .datatypes import StackRunConfig
|
||||
from .utils.config import redact_sensitive_fields
|
||||
|
|
|
|||
|
|
@ -8,29 +8,6 @@ import importlib.metadata
|
|||
import inspect
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.prompts import Prompts
|
||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.core.client import get_client_impl
|
||||
from llama_stack.core.datatypes import (
|
||||
AccessRule,
|
||||
|
|
@ -44,17 +21,44 @@ from llama_stack.core.external import load_external_apis
|
|||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
from llama_stack_api import (
|
||||
LLAMA_STACK_API_V1ALPHA,
|
||||
Agents,
|
||||
Api,
|
||||
Batches,
|
||||
Benchmarks,
|
||||
BenchmarksProtocolPrivate,
|
||||
Conversations,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
DatasetsProtocolPrivate,
|
||||
Eval,
|
||||
ExternalApiSpec,
|
||||
Files,
|
||||
Inference,
|
||||
InferenceProvider,
|
||||
Inspect,
|
||||
Models,
|
||||
ModelsProtocolPrivate,
|
||||
PostTraining,
|
||||
Prompts,
|
||||
ProviderSpec,
|
||||
RemoteProviderConfig,
|
||||
RemoteProviderSpec,
|
||||
Safety,
|
||||
Scoring,
|
||||
ScoringFunctions,
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
Shields,
|
||||
ShieldsProtocolPrivate,
|
||||
ToolGroups,
|
||||
ToolGroupsProtocolPrivate,
|
||||
ToolRuntime,
|
||||
VectorIO,
|
||||
VectorStore,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
Providers as ProvidersAPI,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from llama_stack.core.datatypes import (
|
|||
)
|
||||
from llama_stack.core.stack import StackRunConfig
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack_api import Api, RoutingTable
|
||||
|
||||
|
||||
async def get_routing_table_impl(
|
||||
|
|
|
|||
|
|
@ -6,11 +6,8 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack_api import DatasetIO, DatasetPurpose, DataSource, PaginatedResponse, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,15 +6,18 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
|
||||
from llama_stack.apis.scoring import (
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
BenchmarkConfig,
|
||||
Eval,
|
||||
EvaluateResponse,
|
||||
Job,
|
||||
RoutingTable,
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringFnParams,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
|
|
|||
|
|
@ -15,13 +15,25 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
|||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent
|
||||
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack_api import (
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
ModelNotFoundError,
|
||||
ModelType,
|
||||
ModelTypeError,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
|
|
@ -35,19 +47,8 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
Order,
|
||||
RerankResponse,
|
||||
RoutingTable,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent
|
||||
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
|
@ -416,7 +417,7 @@ class InferenceRouter(Inference):
|
|||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
model_id=fully_qualified_model_id,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for metric in metrics:
|
||||
|
|
|
|||
|
|
@ -6,13 +6,9 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import RunShieldResponse, Safety
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import SafetyConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import RoutingTable
|
||||
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,12 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
URL,
|
||||
)
|
||||
from llama_stack.apis.tools import (
|
||||
ListToolDefsResponse,
|
||||
ToolRuntime,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||
|
||||
|
|
@ -36,16 +34,16 @@ class ToolRuntimeRouter(ToolRuntime):
|
|||
logger.debug("ToolRuntimeRouter.shutdown")
|
||||
pass
|
||||
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
|
||||
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None) -> Any:
|
||||
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
|
||||
provider = await self.routing_table.get_provider_impl(tool_name)
|
||||
return await provider.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
kwargs=kwargs,
|
||||
authorization=authorization,
|
||||
)
|
||||
|
||||
async def list_runtime_tools(
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
|
||||
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
|
||||
return await self.routing_table.list_tools(tool_group_id)
|
||||
return await self.routing_table.list_tools(tool_group_id, authorization=authorization)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,20 @@ from typing import Annotated, Any
|
|||
|
||||
from fastapi import Body
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.vector_io import (
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
Chunk,
|
||||
HealthResponse,
|
||||
HealthStatus,
|
||||
InterleavedContent,
|
||||
ModelNotFoundError,
|
||||
ModelType,
|
||||
ModelTypeError,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||
QueryChunksResponse,
|
||||
RoutingTable,
|
||||
SearchRankingOptions,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategy,
|
||||
|
|
@ -33,9 +40,6 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import VectorStoresConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
|
@ -122,6 +126,14 @@ class VectorIORouter(VectorIO):
|
|||
if embedding_model is not None and embedding_dimension is None:
|
||||
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
|
||||
|
||||
# Validate that embedding model exists and is of the correct type
|
||||
if embedding_model is not None:
|
||||
model = await self.routing_table.get_object_by_identifier("model", embedding_model)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
|
||||
# Auto-select provider if not specified
|
||||
if provider_id is None:
|
||||
num_providers = len(self.routing_table.impls_by_provider_id)
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
from llama_stack.core.datatypes import (
|
||||
BenchmarkWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import Benchmark, Benchmarks, ListBenchmarksResponse
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,6 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import Action
|
||||
from llama_stack.core.datatypes import (
|
||||
|
|
@ -21,7 +18,7 @@ from llama_stack.core.datatypes import (
|
|||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.core.store import DistributionRegistry
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||
from llama_stack_api import Api, Model, ModelNotFoundError, ResourceType, RoutingTable
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,22 +7,22 @@
|
|||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import DatasetNotFoundError
|
||||
from llama_stack.apis.datasets import (
|
||||
from llama_stack.core.datatypes import (
|
||||
DatasetWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
Dataset,
|
||||
DatasetNotFoundError,
|
||||
DatasetPurpose,
|
||||
Datasets,
|
||||
DatasetType,
|
||||
DataSource,
|
||||
ListDatasetsResponse,
|
||||
ResourceType,
|
||||
RowsDataSource,
|
||||
URIDataSource,
|
||||
)
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.core.datatypes import (
|
||||
DatasetWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@
|
|||
import time
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.core.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
|
|
@ -16,6 +14,15 @@ from llama_stack.core.datatypes import (
|
|||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
ListModelsResponse,
|
||||
Model,
|
||||
ModelNotFoundError,
|
||||
Models,
|
||||
ModelType,
|
||||
OpenAIListModelsResponse,
|
||||
OpenAIModel,
|
||||
)
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
|
|
|
|||
|
|
@ -4,18 +4,18 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
ListScoringFunctionsResponse,
|
||||
ScoringFn,
|
||||
ScoringFnParams,
|
||||
ScoringFunctions,
|
||||
)
|
||||
from llama_stack.core.datatypes import (
|
||||
ScoringFnWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
ListScoringFunctionsResponse,
|
||||
ParamType,
|
||||
ResourceType,
|
||||
ScoringFn,
|
||||
ScoringFnParams,
|
||||
ScoringFunctions,
|
||||
)
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -6,12 +6,11 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
|
||||
from llama_stack.core.datatypes import (
|
||||
ShieldWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import ListShieldsResponse, ResourceType, Shield, Shields
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,17 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
URL,
|
||||
ListToolDefsResponse,
|
||||
ListToolGroupsResponse,
|
||||
ToolDef,
|
||||
ToolGroup,
|
||||
ToolGroupNotFoundError,
|
||||
ToolGroups,
|
||||
)
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
||||
|
|
@ -43,7 +49,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
routing_key = self.tool_to_toolgroup[routing_key]
|
||||
return await super().get_provider_impl(routing_key, provider_id)
|
||||
|
||||
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
|
||||
async def list_tools(
|
||||
self, toolgroup_id: str | None = None, authorization: str | None = None
|
||||
) -> ListToolDefsResponse:
|
||||
if toolgroup_id:
|
||||
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
|
||||
toolgroup_id = group_id
|
||||
|
|
@ -55,7 +63,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
for toolgroup in toolgroups:
|
||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||
try:
|
||||
await self._index_tools(toolgroup)
|
||||
await self._index_tools(toolgroup, authorization=authorization)
|
||||
except AuthenticationRequiredError:
|
||||
# Send authentication errors back to the client so it knows
|
||||
# that it needs to supply credentials for remote MCP servers.
|
||||
|
|
@ -70,9 +78,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
|
||||
return ListToolDefsResponse(data=all_tools)
|
||||
|
||||
async def _index_tools(self, toolgroup: ToolGroup):
|
||||
async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
|
||||
provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
|
||||
tooldefs_response = await provider_impl.list_runtime_tools(
|
||||
toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
|
||||
)
|
||||
|
||||
tooldefs = tooldefs_response.data
|
||||
for t in tooldefs:
|
||||
|
|
|
|||
|
|
@ -6,12 +6,17 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorStoreWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
# Removed VectorStores import to avoid exposing public API
|
||||
from llama_stack.apis.vector_io.vector_io import (
|
||||
from llama_stack_api import (
|
||||
ModelNotFoundError,
|
||||
ModelType,
|
||||
ModelTypeError,
|
||||
ResourceType,
|
||||
SearchRankingOptions,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
|
|
@ -22,10 +27,6 @@ from llama_stack.apis.vector_io.vector_io import (
|
|||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.core.datatypes import (
|
||||
VectorStoreWithOwner,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import httpx
|
|||
import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.errors import TokenValidationError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
|
|
@ -23,6 +22,7 @@ from llama_stack.core.datatypes import (
|
|||
User,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import TokenValidationError
|
||||
|
||||
logger = get_logger(name=__name__, category="core::auth")
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,8 @@ from typing import Any
|
|||
from aiohttp import hdrs
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
from llama_stack_api import Api, ExternalApiSpec, WebMethod
|
||||
|
||||
EndpointFunc = Callable[..., Any]
|
||||
PathParams = dict[str, str]
|
||||
|
|
|
|||
|
|
@ -31,8 +31,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
|
|
@ -58,7 +56,7 @@ from llama_stack.core.utils.config import redact_sensitive_fields
|
|||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import LoggingConfig, get_logger, setup_logging
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack_api import Api, ConflictError, PaginatedResponse, ResourceNotFoundError
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
|
|
@ -526,8 +524,8 @@ def extract_path_params(route: str) -> list[str]:
|
|||
|
||||
def remove_disabled_providers(obj):
|
||||
if isinstance(obj, dict):
|
||||
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
|
||||
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
|
||||
# Filter out items where provider_id is explicitly disabled or empty
|
||||
if "provider_id" in obj and obj["provider_id"] in ("__disabled__", "", None):
|
||||
return None
|
||||
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
|
||||
elif isinstance(obj, list):
|
||||
|
|
|
|||
|
|
@ -13,26 +13,6 @@ from typing import Any
|
|||
|
||||
import yaml
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.prompts import Prompts
|
||||
from llama_stack.apis.providers import Providers
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
|
|
@ -54,7 +34,30 @@ from llama_stack.core.storage.datatypes import (
|
|||
from llama_stack.core.store.registry import create_dist_registry
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack_api import (
|
||||
Agents,
|
||||
Api,
|
||||
Batches,
|
||||
Benchmarks,
|
||||
Conversations,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
Eval,
|
||||
Files,
|
||||
Inference,
|
||||
Inspect,
|
||||
Models,
|
||||
PostTraining,
|
||||
Prompts,
|
||||
Providers,
|
||||
Safety,
|
||||
Scoring,
|
||||
ScoringFunctions,
|
||||
Shields,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
VectorIO,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
from llama_stack_api import json_schema_type, register_schema
|
||||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
|
|
@ -17,6 +16,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
|
|||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
|
||||
from llama_stack_api import ModelType
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
|
|
@ -22,6 +21,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
|
|||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack_api import ModelType
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
|
|
|
|||
|
|
@ -5,8 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
BenchmarkInput,
|
||||
BuildProvider,
|
||||
|
|
@ -34,6 +32,7 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
|
|||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack_api import DatasetPurpose, ModelType, URIDataSource
|
||||
|
||||
|
||||
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from llama_stack.core.datatypes import (
|
|||
)
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
|
|
@ -38,6 +37,7 @@ from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOC
|
|||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
from llama_stack_api import RemoteProviderSpec
|
||||
|
||||
|
||||
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -12,8 +12,6 @@ import rich
|
|||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.datasets import DatasetPurpose
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
Api,
|
||||
|
|
@ -44,6 +42,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
|||
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
|
||||
from llama_stack_api import DatasetPurpose, ModelType
|
||||
|
||||
|
||||
def filter_empty_values(obj: Any) -> Any:
|
||||
|
|
|
|||
|
|
@ -5,29 +5,29 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
from llama_stack_api import (
|
||||
Agents,
|
||||
Conversations,
|
||||
Inference,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponsePrompt,
|
||||
OpenAIResponseText,
|
||||
Order,
|
||||
ResponseGuardrail,
|
||||
Safety,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrail
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
|
|
|||
|
|
@ -10,12 +10,20 @@ from collections.abc import AsyncIterator
|
|||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
ConversationItem,
|
||||
Conversations,
|
||||
Inference,
|
||||
InvalidConversationIdError,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputTool,
|
||||
|
|
@ -25,24 +33,13 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponsePrompt,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.conversations.conversations import ConversationItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.responses.responses_store import (
|
||||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
Order,
|
||||
ResponseGuardrailSpec,
|
||||
Safety,
|
||||
ToolGroups,
|
||||
ToolRuntime,
|
||||
VectorIO,
|
||||
)
|
||||
|
||||
from .streaming import StreamingResponseOrchestrator
|
||||
|
|
@ -260,6 +257,19 @@ class OpenAIResponsesImpl:
|
|||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
||||
# Validate MCP tools: ensure Authorization header is not passed via headers dict
|
||||
if tools:
|
||||
from llama_stack_api.openai_responses import OpenAIResponseInputToolMCP
|
||||
|
||||
for tool in tools:
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.headers:
|
||||
for key in tool.headers.keys():
|
||||
if key.lower() == "authorization":
|
||||
raise ValueError(
|
||||
"Authorization header cannot be passed via 'headers'. "
|
||||
"Please use the 'authorization' parameter instead."
|
||||
)
|
||||
|
||||
guardrail_ids = extract_guardrail_ids(guardrails) if guardrails else []
|
||||
|
||||
if conversation is not None:
|
||||
|
|
|
|||
|
|
@ -8,10 +8,21 @@ import uuid
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack_api import (
|
||||
AllowedToolsFilter,
|
||||
ApprovalFilter,
|
||||
Inference,
|
||||
MCPListToolsTool,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseContentPartOutputText,
|
||||
OpenAIResponseContentPartReasoningText,
|
||||
OpenAIResponseContentPartRefusal,
|
||||
|
|
@ -56,19 +67,6 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseUsageOutputTokensDetails,
|
||||
WebSearchToolTypes,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChoice,
|
||||
OpenAIMessageParam,
|
||||
)
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
from .types import ChatCompletionContext, ChatCompletionResult
|
||||
from .utils import (
|
||||
|
|
@ -1025,9 +1023,9 @@ class StreamingResponseOrchestrator:
|
|||
"""Process all tools and emit appropriate streaming events."""
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
|
||||
from llama_stack_api import ToolDef
|
||||
|
||||
def make_openai_tool(tool_name: str, tool: ToolDef) -> ChatCompletionToolParam:
|
||||
tool_def = ToolDefinition(
|
||||
|
|
@ -1093,10 +1091,12 @@ class StreamingResponseOrchestrator:
|
|||
"server_url": mcp_tool.server_url,
|
||||
"mcp_list_tools_id": list_id,
|
||||
}
|
||||
# List MCP tools with authorization from tool config
|
||||
async with tracing.span("list_mcp_tools", attributes):
|
||||
tool_defs = await list_mcp_tools(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
headers=mcp_tool.headers,
|
||||
authorization=mcp_tool.authorization,
|
||||
)
|
||||
|
||||
# Create the MCP list tools message
|
||||
|
|
|
|||
|
|
@ -9,7 +9,14 @@ import json
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import (
|
||||
ImageContentItem,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIImageURL,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseObjectStreamResponseFileSearchCallCompleted,
|
||||
|
|
@ -23,24 +30,14 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIImageURL,
|
||||
OpenAIToolMessageParam,
|
||||
TextContentItem,
|
||||
ToolGroups,
|
||||
ToolInvocationResult,
|
||||
ToolRuntime,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .types import ChatCompletionContext, ToolExecutionResult
|
||||
|
||||
|
|
@ -299,12 +296,14 @@ class ToolExecutor:
|
|||
"server_url": mcp_tool.server_url,
|
||||
"tool_name": function_name,
|
||||
}
|
||||
# Invoke MCP tool with authorization from tool config
|
||||
async with tracing.span("invoke_mcp_tool", attributes):
|
||||
result = await invoke_mcp_tool(
|
||||
endpoint=mcp_tool.server_url,
|
||||
headers=mcp_tool.headers or {},
|
||||
tool_name=function_name,
|
||||
kwargs=tool_kwargs,
|
||||
headers=mcp_tool.headers,
|
||||
authorization=mcp_tool.authorization,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = (
|
||||
|
|
@ -398,6 +397,10 @@ class ToolExecutor:
|
|||
# Build output message
|
||||
message: Any
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack_api import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=item_id,
|
||||
arguments=function.arguments,
|
||||
|
|
|
|||
|
|
@ -10,7 +10,10 @@ from typing import cast
|
|||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
from llama_stack_api import (
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
|
|
@ -26,7 +29,6 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseTool,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
|
|
|
|||
|
|
@ -9,9 +9,23 @@ import re
|
|||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
from llama_stack_api import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseAnnotationFileCitation,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContent,
|
||||
|
|
@ -27,28 +41,12 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseText,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIImageURL,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
ResponseGuardrailSpec,
|
||||
Safety,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
|
||||
|
||||
async def convert_chat_choice_to_response_message(
|
||||
|
|
|
|||
|
|
@ -6,10 +6,9 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack_api import OpenAIMessageParam, Safety, SafetyViolation, ViolationLevel
|
||||
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,9 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack_api import Files, Inference, Models
|
||||
|
||||
from .batches import ReferenceBatchesImpl
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
|
|
|||
|
|
@ -16,24 +16,28 @@ from typing import Any, Literal
|
|||
from openai.types.batch import BatchError, Errors
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack_api import (
|
||||
Batches,
|
||||
BatchObject,
|
||||
ConflictError,
|
||||
Files,
|
||||
Inference,
|
||||
ListBatchesResponse,
|
||||
Models,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIFilePurpose,
|
||||
OpenAIMessageParam,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -5,13 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Dataset
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack_api import Dataset, DatasetIO, DatasetsProtocolPrivate, PaginatedResponse
|
||||
|
||||
from .config import LocalFSDatasetIOConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -8,24 +8,27 @@ from typing import Any
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack_api import (
|
||||
Agents,
|
||||
Benchmark,
|
||||
BenchmarkConfig,
|
||||
BenchmarksProtocolPrivate,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
Eval,
|
||||
EvaluateResponse,
|
||||
Inference,
|
||||
Job,
|
||||
JobStatus,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
Scoring,
|
||||
)
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .....apis.common.job_types import Job, JobStatus
|
||||
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
EVAL_TASKS_PREFIX = "benchmarks:"
|
||||
|
|
|
|||
|
|
@ -11,16 +11,6 @@ from typing import Annotated
|
|||
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.core.id_generation import generate_object_id
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -28,6 +18,16 @@ from llama_stack.providers.utils.files.form_data import parse_expires_after
|
|||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from llama_stack_api import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
Order,
|
||||
ResourceNotFoundError,
|
||||
)
|
||||
|
||||
from .config import LocalfsFilesImplConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from llama_stack.apis.inference import QuantizationConfig
|
||||
from llama_stack.providers.utils.inference import supported_inference_models
|
||||
from llama_stack_api import QuantizationConfig
|
||||
|
||||
|
||||
class MetaReferenceInferenceConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,13 @@ from typing import Optional
|
|||
import torch
|
||||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||
from llama_stack_api import (
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
|
|
@ -20,12 +26,6 @@ from llama_stack.apis.inference import (
|
|||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode, ToolPromptFormat
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model, ModelFamily
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
|
|
|||
|
|
@ -9,22 +9,6 @@ import time
|
|||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionUsage,
|
||||
OpenAIChoice,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import RawMessage, RawTextItem, ToolDefinition
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
|
|
@ -40,7 +24,6 @@ from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
|||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
|
|
@ -48,6 +31,22 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
InferenceProvider,
|
||||
Model,
|
||||
ModelsProtocolPrivate,
|
||||
ModelType,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChatCompletionUsage,
|
||||
OpenAIChoice,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generators import LlamaGenerator
|
||||
|
|
@ -376,7 +375,7 @@ class MetaReferenceInferenceImpl(
|
|||
# Convert tool calls to OpenAI format
|
||||
openai_tool_calls = None
|
||||
if decoded_message.tool_calls:
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack_api import (
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
)
|
||||
|
|
@ -441,15 +440,15 @@ class MetaReferenceInferenceImpl(
|
|||
params: OpenAIChatCompletionRequestWithExtraBody,
|
||||
) -> AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
"""Stream chat completion chunks as they're generated."""
|
||||
from llama_stack.apis.inference import (
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
|
||||
from llama_stack_api import (
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoiceDelta,
|
||||
OpenAIChunkChoice,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import decode_assistant_message
|
||||
|
||||
response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
created = int(time.time())
|
||||
|
|
|
|||
|
|
@ -6,22 +6,21 @@
|
|||
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
InferenceProvider,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAICompletion,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
InferenceProvider,
|
||||
Model,
|
||||
ModelsProtocolPrivate,
|
||||
ModelType,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
)
|
||||
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -12,14 +12,10 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.type_system import (
|
||||
ChatCompletionInputType,
|
||||
DialogType,
|
||||
StringType,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
ColumnName,
|
||||
)
|
||||
from llama_stack_api import ChatCompletionInputType, DialogType, StringType
|
||||
|
||||
EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = {
|
||||
"instruct": [
|
||||
|
|
|
|||
|
|
@ -6,11 +6,16 @@
|
|||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack_api import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
|
|
@ -19,11 +24,6 @@ from llama_stack.apis.post_training import (
|
|||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface.config import (
|
||||
HuggingFacePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
|
|
|
|||
|
|
@ -18,16 +18,16 @@ from transformers import (
|
|||
)
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
from llama_stack_api import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
LoraFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
|
|
|
|||
|
|
@ -16,15 +16,15 @@ from transformers import (
|
|||
)
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
from llama_stack_api import (
|
||||
Checkpoint,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
DPOAlignmentConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
|
||||
|
||||
from ..config import HuggingFacePostTrainingConfig
|
||||
from ..utils import (
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ import torch
|
|||
from datasets import Dataset
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from llama_stack_api import Checkpoint, DatasetIO, TrainingConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
|
@ -34,8 +36,6 @@ class HFAutoModel(Protocol):
|
|||
def save_pretrained(self, save_directory: str | Path) -> None: ...
|
||||
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .config import HuggingFacePostTrainingConfig
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ from torchtune.models.llama3_1 import lora_llama3_1_8b
|
|||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
from torchtune.modules.transforms import Transform
|
||||
|
||||
from llama_stack.apis.post_training import DatasetFormat
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
from llama_stack_api import DatasetFormat
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
|
|
|||
|
|
@ -6,11 +6,16 @@
|
|||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
from llama_stack_api import (
|
||||
AlgorithmConfig,
|
||||
Checkpoint,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
ListPostTrainingJobsResponse,
|
||||
|
|
@ -20,11 +25,6 @@ from llama_stack.apis.post_training import (
|
|||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.config import (
|
||||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
|
||||
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
|
||||
|
||||
|
||||
class TrainingArtifactType(Enum):
|
||||
|
|
|
|||
|
|
@ -32,17 +32,6 @@ from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
|||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.post_training import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -56,6 +45,17 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
from llama_stack_api import (
|
||||
Checkpoint,
|
||||
DataConfig,
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
LoraFinetuningConfig,
|
||||
OptimizerConfig,
|
||||
PostTrainingMetric,
|
||||
QATFinetuningConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
|
||||
log = get_logger(name=__name__, category="post_training")
|
||||
|
||||
|
|
|
|||
|
|
@ -10,19 +10,20 @@ from typing import TYPE_CHECKING, Any
|
|||
if TYPE_CHECKING:
|
||||
from codeshield.cs import CodeShieldScanResult
|
||||
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
ModerationObject,
|
||||
ModerationObjectResults,
|
||||
OpenAIMessageParam,
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
Shield,
|
||||
ViolationLevel,
|
||||
)
|
||||
|
||||
from .config import CodeScannerConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -9,29 +9,29 @@ import uuid
|
|||
from string import Template
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.safety import (
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
ImageContentItem,
|
||||
Inference,
|
||||
ModerationObject,
|
||||
ModerationObjectResults,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
Shield,
|
||||
ShieldsProtocolPrivate,
|
||||
TextContentItem,
|
||||
ViolationLevel,
|
||||
)
|
||||
|
||||
from .config import LlamaGuardConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -9,20 +9,20 @@ from typing import Any
|
|||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import (
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
from llama_stack_api import (
|
||||
ModerationObject,
|
||||
OpenAIMessageParam,
|
||||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
Shield,
|
||||
ShieldsProtocolPrivate,
|
||||
ShieldStore,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.safety.safety import ModerationObject
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.utils.model_utils import model_local_dir
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
from .config import PromptGuardConfig, PromptGuardType
|
||||
|
||||
|
|
|
|||
|
|
@ -5,21 +5,22 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.scoring import (
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringResult,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.utils.common.data_schema_validator import (
|
||||
get_valid_schemas,
|
||||
validate_dataset_schema,
|
||||
)
|
||||
from llama_stack_api import (
|
||||
DatasetIO,
|
||||
Datasets,
|
||||
ScoreBatchResponse,
|
||||
ScoreResponse,
|
||||
Scoring,
|
||||
ScoringFn,
|
||||
ScoringFnParams,
|
||||
ScoringFunctionsProtocolPrivate,
|
||||
ScoringResult,
|
||||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||
|
|
|
|||
|
|
@ -8,9 +8,8 @@ import json
|
|||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
from llama_stack_api import ScoringFnParams, ScoringResultRow
|
||||
|
||||
from .fn_defs.docvqa import docvqa
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
from llama_stack_api import ScoringFnParams, ScoringResultRow
|
||||
|
||||
from .fn_defs.equality import equality
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
from llama_stack_api import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
NumberType,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
from llama_stack_api import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
NumberType,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
from llama_stack_api import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
NumberType,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
from llama_stack_api import (
|
||||
AggregationFunctionType,
|
||||
NumberType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
from llama_stack_api import (
|
||||
AggregationFunctionType,
|
||||
NumberType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
from llama_stack_api import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
NumberType,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,8 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
from llama_stack_api import ScoringFnParams, ScoringResultRow
|
||||
|
||||
from .fn_defs.ifeval import (
|
||||
ifeval,
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
from llama_stack_api import ScoringFnParams, ScoringFnParamsType, ScoringResultRow
|
||||
|
||||
from ..utils.math_utils import first_answer, normalize_final_answer, try_evaluate_frac, try_evaluate_latex
|
||||
from .fn_defs.regex_parser_math_response import (
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue