Merge remote-tracking branch 'upstream/main' into add-file-processor-skeleton

commit 479e627b0d
645 changed files with 96136 additions and 35829 deletions
@@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .batches import Batches, BatchObject, ListBatchesResponse

__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .benchmarks import *

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .conversations import (
    Conversation,
    ConversationDeletedResource,
    ConversationItem,
    ConversationItemCreateRequest,
    ConversationItemDeletedResource,
    ConversationItemList,
    Conversations,
    Metadata,
)

__all__ = [
    "Conversation",
    "ConversationDeletedResource",
    "ConversationItem",
    "ConversationItemCreateRequest",
    "ConversationItemDeletedResource",
    "ConversationItemList",
    "Conversations",
    "Metadata",
]

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .datasets import *
@@ -1,159 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum, EnumMeta

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class DynamicApiMeta(EnumMeta):
    def __new__(cls, name, bases, namespace):
        # Store the original enum values
        original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}

        # Create the enum class
        cls = super().__new__(cls, name, bases, namespace)

        # Store the original values for reference
        cls._original_values = original_values
        # Initialize _dynamic_values
        cls._dynamic_values = {}

        return cls

    def __call__(cls, value):
        try:
            return super().__call__(value)
        except ValueError as e:
            # If this value was already dynamically added, return it
            if value in cls._dynamic_values:
                return cls._dynamic_values[value]

            # If the value doesn't exist, create a new enum member
            # Create a new member name from the value
            member_name = value.lower().replace("-", "_")

            # If this member name already exists in the enum, return the existing member
            if member_name in cls._member_map_:
                return cls._member_map_[member_name]

            # Instead of creating a new member, raise ValueError to force users to use Api.add() to
            # register new APIs explicitly
            raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e

    def __iter__(cls):
        # Allow iteration over both static and dynamic members
        yield from super().__iter__()
        if hasattr(cls, "_dynamic_values"):
            yield from cls._dynamic_values.values()

    def add(cls, value):
        """
        Add a new API to the enum.
        Used to register external APIs.
        """
        member_name = value.lower().replace("-", "_")

        # If this member name already exists in the enum, return it
        if member_name in cls._member_map_:
            return cls._member_map_[member_name]

        # Create a new enum member
        member = object.__new__(cls)
        member._name_ = member_name
        member._value_ = value

        # Add it to the enum class
        cls._member_map_[member_name] = member
        cls._member_names_.append(member_name)
        cls._member_type_ = str

        # Store it in our dynamic values
        cls._dynamic_values[value] = member

        return member


@json_schema_type
class Api(Enum, metaclass=DynamicApiMeta):
    """Enumeration of all available APIs in the Llama Stack system.
    :cvar providers: Provider management and configuration
    :cvar inference: Text generation, chat completions, and embeddings
    :cvar safety: Content moderation and safety shields
    :cvar agents: Agent orchestration and execution
    :cvar batches: Batch processing for asynchronous API requests
    :cvar vector_io: Vector database operations and queries
    :cvar datasetio: Dataset input/output operations
    :cvar scoring: Model output evaluation and scoring
    :cvar eval: Model evaluation and benchmarking framework
    :cvar post_training: Fine-tuning and model training
    :cvar tool_runtime: Tool execution and management
    :cvar telemetry: Observability and system monitoring
    :cvar models: Model metadata and management
    :cvar shields: Safety shield implementations
    :cvar datasets: Dataset creation and management
    :cvar scoring_functions: Scoring function definitions
    :cvar benchmarks: Benchmark suite management
    :cvar tool_groups: Tool group organization
    :cvar files: File storage and management
    :cvar prompts: Prompt versions and management
    :cvar inspect: Built-in system inspection and introspection
    """

    providers = "providers"
    inference = "inference"
    safety = "safety"
    agents = "agents"
    batches = "batches"
    vector_io = "vector_io"
    datasetio = "datasetio"
    scoring = "scoring"
    eval = "eval"
    post_training = "post_training"
    tool_runtime = "tool_runtime"

    models = "models"
    shields = "shields"
    vector_stores = "vector_stores"  # only used for routing table
    datasets = "datasets"
    scoring_functions = "scoring_functions"
    benchmarks = "benchmarks"
    tool_groups = "tool_groups"
    files = "files"
    prompts = "prompts"
    conversations = "conversations"
    file_processor = "file_processor"

    # built-in API
    inspect = "inspect"


@json_schema_type
class Error(BaseModel):
    """
    Error response from the API. Roughly follows RFC 7807.

    :param status: HTTP status code
    :param title: Error title, a short summary of the error which is invariant for an error type
    :param detail: Error detail, a longer human-readable description of the error
    :param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
    """

    status: int
    title: str
    detail: str
    instance: str | None = None


class ExternalApiSpec(BaseModel):
    """Specification for an external API implementation."""

    module: str = Field(..., description="Python module containing the API implementation")
    name: str = Field(..., description="Name of the API")
    pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
    protocol: str = Field(..., description="Name of the protocol class for the API")
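For context on the file removed above: the DynamicApiMeta metaclass lets external providers register new Api members at runtime via Api.add(), while plain Api(...) lookups reject unknown values. Below is a minimal, self-contained sketch of that registration pattern, written for illustration only; it is not the Llama Stack source, and the class names are stand-ins.

from enum import Enum, EnumMeta


class _DynamicMeta(EnumMeta):
    """Sketch: unknown values must be registered explicitly via add()."""

    def __new__(mcls, name, bases, namespace):
        cls = super().__new__(mcls, name, bases, namespace)
        cls._dynamic_values = {}  # value -> dynamically added member
        return cls

    def __call__(cls, value):
        try:
            return super().__call__(value)
        except ValueError as e:
            if value in cls._dynamic_values:  # previously registered at runtime
                return cls._dynamic_values[value]
            raise ValueError(f"'{value}' is not registered; use {cls.__name__}.add()") from e

    def add(cls, value):
        name = value.lower().replace("-", "_")
        if name in cls._member_map_:  # already present (static or dynamic)
            return cls._member_map_[name]
        member = object.__new__(cls)  # bypass Enum's by-value constructor
        member._name_ = name
        member._value_ = value
        cls._member_map_[name] = member
        cls._member_names_.append(name)
        cls._dynamic_values[value] = member
        return member


class Api(Enum, metaclass=_DynamicMeta):
    inference = "inference"
    safety = "safety"


weather = Api.add("weather")       # register a hypothetical external API at runtime
assert Api("weather") is weather   # lookup now resolves to the registered member
assert Api("inference") is Api.inference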
@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .eval import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .file_processor import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .files import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .inference import *
@@ -1,43 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from termcolor import cprint

from llama_stack.apis.inference import (
    ChatCompletionResponseEventType,
    ChatCompletionResponseStreamChunk,
)


class LogEvent:
    def __init__(
        self,
        content: str = "",
        end: str = "\n",
        color="white",
    ):
        self.content = content
        self.color = color
        self.end = "\n" if end is None else end

    def print(self, flush=True):
        cprint(f"{self.content}", color=self.color, end=self.end, flush=flush)


class EventLogger:
    async def log(self, event_generator):
        async for chunk in event_generator:
            if isinstance(chunk, ChatCompletionResponseStreamChunk):
                event = chunk.event
                if event.event_type == ChatCompletionResponseEventType.start:
                    yield LogEvent("Assistant> ", color="cyan", end="")
                elif event.event_type == ChatCompletionResponseEventType.progress:
                    yield LogEvent(event.delta, color="yellow", end="")
                elif event.event_type == ChatCompletionResponseEventType.complete:
                    yield LogEvent("")
            else:
                yield LogEvent("Assistant> ", color="cyan", end="")
                yield LogEvent(chunk.completion_message.content, color="yellow")
@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .inspect import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .models import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .post_training import *

@@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .prompts import ListPromptsResponse, Prompt, Prompts

__all__ = ["Prompt", "Prompts", "ListPromptsResponse"]

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .providers import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .safety import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .scoring import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .scoring_functions import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .shields import *

@@ -1,8 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .rag_tool import *
from .tools import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .vector_io import *

@@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .vector_stores import *
@@ -21,7 +21,7 @@ from llama_stack.core.datatypes import (
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import replace_env_vars
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api

TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"

@@ -32,7 +32,7 @@ from llama_stack.core.storage.datatypes import (
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api

TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import importlib.resources
import sys

from pydantic import BaseModel

@@ -12,12 +11,9 @@ from termcolor import cprint

from llama_stack.core.datatypes import BuildConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.external import load_external_apis
from llama_stack.core.utils.exec import run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.distributions.template import DistributionTemplate
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api

log = get_logger(name=__name__, category="core")

@@ -101,64 +97,3 @@ def print_pip_install_help(config: BuildConfig):
    for special_dep in special_deps:
        cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
    print()


def build_image(
    build_config: BuildConfig,
    image_name: str,
    distro_or_config: str,
    run_config: str | None = None,
):
    container_base = build_config.distribution_spec.container_image or "python:3.12-slim"

    normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
    normal_deps += SERVER_DEPENDENCIES
    if build_config.external_apis_dir:
        external_apis = load_external_apis(build_config)
        if external_apis:
            for _, api_spec in external_apis.items():
                normal_deps.extend(api_spec.pip_packages)

    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
        script = str(importlib.resources.files("llama_stack") / "core/build_container.sh")
        args = [
            script,
            "--distro-or-config",
            distro_or_config,
            "--image-name",
            image_name,
            "--container-base",
            container_base,
            "--normal-deps",
            " ".join(normal_deps),
        ]
        # When building from a config file (not a template), include the run config path in the
        # build arguments
        if run_config is not None:
            args.extend(["--run-config", run_config])
    else:
        script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh")
        args = [
            script,
            "--env-name",
            str(image_name),
            "--normal-deps",
            " ".join(normal_deps),
        ]

    # Always pass both arguments, even if empty, to maintain consistent positional arguments
    if special_deps:
        args.extend(["--optional-deps", "#".join(special_deps)])
    if external_provider_deps:
        args.extend(
            ["--external-provider-deps", "#".join(external_provider_deps)]
        )  # the script will install external provider module, get its deps, and install those too.

    return_code = run_command(args)

    if return_code != 0:
        log.error(
            f"Failed to build target {image_name} with return code {return_code}",
        )

    return return_code
@@ -15,7 +15,7 @@ import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint

from llama_stack.providers.datatypes import RemoteProviderConfig
from llama_stack_api import RemoteProviderConfig

_CLIENT_CLASSES = {}

@@ -20,7 +20,7 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack_api import Api, ProviderSpec

logger = get_logger(name=__name__, category="core")
@@ -10,7 +10,11 @@ from typing import Any, Literal

from pydantic import BaseModel, TypeAdapter

from llama_stack.apis.conversations.conversations import (
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.core.storage.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.core.storage.sqlstore.sqlstore import sqlstore_impl
from llama_stack.log import get_logger
from llama_stack_api import (
    Conversation,
    ConversationDeletedResource,
    ConversationItem,

@@ -20,11 +24,7 @@ from llama_stack.apis.conversations.conversations import (
    Conversations,
    Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType

logger = get_logger(name=__name__, category="openai_conversations")

@@ -203,16 +203,11 @@ class ConversationServiceImpl(Conversations):
                "item_data": item_dict,
            }

            # TODO: Add support for upsert in sql_store, this will fail first if ID exists and then update
            try:
                await self.sql_store.insert(table="conversation_items", data=item_record)
            except Exception:
                # If insert fails due to ID conflict, update existing record
                await self.sql_store.update(
                    table="conversation_items",
                    data={"created_at": created_at, "item_data": item_dict},
                    where={"id": item_id},
                )
            await self.sql_store.upsert(
                table="conversation_items",
                data=item_record,
                conflict_columns=["id"],
            )

            created_items.append(item_dict)
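The hunk above replaces the insert-then-update fallback with a single sql_store.upsert(..., conflict_columns=["id"]) call. The SQL emitted by the sqlstore layer is not shown in this diff; assuming standard ON CONFLICT semantics, the behaviour corresponds to the sketch below (SQLite and the column layout are used purely for illustration).

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE conversation_items (id TEXT PRIMARY KEY, created_at INTEGER, item_data TEXT)")


def upsert(item_id: str, created_at: int, item_data: str) -> None:
    # Insert the row, or update it in place when the id already exists.
    conn.execute(
        """
        INSERT INTO conversation_items (id, created_at, item_data)
        VALUES (?, ?, ?)
        ON CONFLICT(id) DO UPDATE SET created_at = excluded.created_at,
                                      item_data = excluded.item_data
        """,
        (item_id, created_at, item_data),
    )


upsert("item-1", 1, '{"type": "message"}')
upsert("item-1", 2, '{"type": "message", "edited": true}')  # second call updates the existing row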
@@ -11,20 +11,6 @@ from urllib.parse import urlparse

from pydantic import BaseModel, Field, field_validator, model_validator

from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.resource import Resource
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput
from llama_stack.apis.shields import Shield, ShieldInput
from llama_stack.apis.tools import ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore, VectorStoreInput
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.storage.datatypes import (
    KVStoreReference,

@@ -32,7 +18,32 @@ from llama_stack.core.storage.datatypes import (
    StorageConfig,
)
from llama_stack.log import LoggingConfig
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack_api import (
    Api,
    Benchmark,
    BenchmarkInput,
    Dataset,
    DatasetInput,
    DatasetIO,
    Eval,
    Inference,
    Model,
    ModelInput,
    ProviderSpec,
    Resource,
    Safety,
    Scoring,
    ScoringFn,
    ScoringFnInput,
    Shield,
    ShieldInput,
    ToolGroup,
    ToolGroupInput,
    ToolRuntime,
    VectorIO,
    VectorStore,
    VectorStoreInput,
)

LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2
@@ -15,7 +15,7 @@ from pydantic import BaseModel
from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
from llama_stack_api import (
    Api,
    InlineProviderSpec,
    ProviderSpec,

@@ -7,9 +7,9 @@

import yaml

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
from llama_stack_api import Api, ExternalApiSpec

logger = get_logger(name=__name__, category="core")
@@ -8,18 +8,17 @@ from importlib.metadata import version

from pydantic import BaseModel

from llama_stack.apis.inspect import (
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack_api import (
    HealthInfo,
    HealthStatus,
    Inspect,
    ListRoutesResponse,
    RouteInfo,
    VersionInfo,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus


class DistributionInspectConfig(BaseModel):

@@ -46,8 +45,8 @@ class DistributionInspectImpl(Inspect):
        # Helper function to determine if a route should be included based on api_filter
        def should_include_route(webmethod) -> bool:
            if api_filter is None:
                # Default: only non-deprecated v1 APIs
                return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
                # Default: only non-deprecated APIs
                return not webmethod.deprecated
            elif api_filter == "deprecated":
                # Special filter: show deprecated routes regardless of their actual level
                return bool(webmethod.deprecated)
@@ -19,6 +19,8 @@ import httpx
import yaml
from fastapi import Response as FastAPIResponse

from llama_stack.core.utils.type_inspection import is_unwrapped_body_param

try:
    from llama_stack_client import (
        NOT_GIVEN,

@@ -40,24 +42,16 @@ from termcolor import cprint
from llama_stack.core.build import print_pip_install_help
from llama_stack.core.configure import parse_and_maybe_upgrade_config
from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
from llama_stack.core.request_headers import (
    PROVIDER_DATA_VAR,
    request_provider_data_context,
)
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, request_provider_data_context
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
    Stack,
    get_stack_run_config_from_distro,
    replace_env_vars,
)
from llama_stack.core.stack import Stack, get_stack_run_config_from_distro, replace_env_vars
from llama_stack.core.telemetry import Telemetry
from llama_stack.core.telemetry.tracing import CURRENT_TRACE_CONTEXT, end_trace, setup_logger, start_trace
from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.core.utils.exec import in_notebook
from llama_stack.log import get_logger, setup_logging
from llama_stack.strong_typing.inspection import is_unwrapped_body_param

logger = get_logger(name=__name__, category="core")

@@ -389,6 +383,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
        matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
        body |= path_params

        # Pass through params that aren't already handled as path params
        if options.params:
            extra_query_params = {k: v for k, v in options.params.items() if k not in path_params}
            if extra_query_params:
                body["extra_query"] = extra_query_params

        body, field_names = self._handle_file_uploads(options, body)

        body = self._convert_body(matched_func, body, exclude_params=set(field_names))
@@ -9,9 +9,9 @@ from typing import Any

from pydantic import BaseModel

from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
from llama_stack_api import ListPromptsResponse, Prompt, Prompts


class PromptServiceConfig(BaseModel):

@@ -9,9 +9,8 @@ from typing import Any

from pydantic import BaseModel

from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus
from llama_stack_api import HealthResponse, HealthStatus, ListProvidersResponse, ProviderInfo, Providers

from .datatypes import StackRunConfig
from .utils.config import redact_sensitive_fields
@@ -8,30 +8,6 @@ import importlib.metadata
import inspect
from typing import Any

from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.file_processor import FileProcessor
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers as ProvidersAPI
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.core.client import get_client_impl
from llama_stack.core.datatypes import (
    AccessRule,

@@ -45,17 +21,45 @@ from llama_stack.core.external import load_external_apis
from llama_stack.core.store import DistributionRegistry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
from llama_stack_api import (
    LLAMA_STACK_API_V1ALPHA,
    Agents,
    Api,
    Batches,
    Benchmarks,
    BenchmarksProtocolPrivate,
    Conversations,
    DatasetIO,
    Datasets,
    DatasetsProtocolPrivate,
    Eval,
    ExternalApiSpec,
    FileProcessor,
    Files,
    Inference,
    InferenceProvider,
    Inspect,
    Models,
    ModelsProtocolPrivate,
    PostTraining,
    Prompts,
    ProviderSpec,
    RemoteProviderConfig,
    RemoteProviderSpec,
    Safety,
    Scoring,
    ScoringFunctions,
    ScoringFunctionsProtocolPrivate,
    Shields,
    ShieldsProtocolPrivate,
    ToolGroups,
    ToolGroupsProtocolPrivate,
    ToolRuntime,
    VectorIO,
    VectorStore,
)
from llama_stack_api import (
    Providers as ProvidersAPI,
)

logger = get_logger(name=__name__, category="core")
@@ -12,8 +12,8 @@ from llama_stack.core.datatypes import (
)
from llama_stack.core.stack import StackRunConfig
from llama_stack.core.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import Api, RoutingTable


async def get_routing_table_impl(

@@ -6,11 +6,8 @@

from typing import Any

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack_api import DatasetIO, DatasetPurpose, DataSource, PaginatedResponse, RoutingTable

logger = get_logger(name=__name__, category="core::routers")

@@ -6,15 +6,18 @@

from typing import Any

from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack_api import (
    BenchmarkConfig,
    Eval,
    EvaluateResponse,
    Job,
    RoutingTable,
    ScoreBatchResponse,
    ScoreResponse,
    Scoring,
    ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core::routers")
@@ -15,13 +15,25 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import TypeAdapter

from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
from llama_stack.core.telemetry.telemetry import MetricEvent
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack_api import (
    HealthResponse,
    HealthStatus,
    Inference,
    ListOpenAIChatCompletionResponse,
    ModelNotFoundError,
    ModelType,
    ModelTypeError,
    OpenAIAssistantMessageParam,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionRequestWithExtraBody,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,

@@ -35,19 +47,8 @@ from llama_stack.apis.inference import (
    OpenAIMessageParam,
    Order,
    RerankResponse,
    RoutingTable,
)
from llama_stack.apis.inference.inference import (
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import ModelType
from llama_stack.core.telemetry.telemetry import MetricEvent
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore

logger = get_logger(name=__name__, category="core::routers")

@@ -416,7 +417,7 @@ class InferenceRouter(Inference):
                    prompt_tokens=chunk.usage.prompt_tokens,
                    completion_tokens=chunk.usage.completion_tokens,
                    total_tokens=chunk.usage.total_tokens,
                    model_id=fully_qualified_model_id,
                    fully_qualified_model_id=fully_qualified_model_id,
                    provider_id=provider_id,
                )
                for metric in metrics:
@@ -6,13 +6,9 @@

from typing import Any

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield
from llama_stack.core.datatypes import SafetyConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
from llama_stack_api import ModerationObject, OpenAIMessageParam, RoutingTable, RunShieldResponse, Safety, Shield

logger = get_logger(name=__name__, category="core::routers")

@@ -52,7 +48,7 @@ class SafetyRouter(Safety):
    async def run_shield(
        self,
        shield_id: str,
        messages: list[Message],
        messages: list[OpenAIMessageParam],
        params: dict[str, Any] = None,
    ) -> RunShieldResponse:
        logger.debug(f"SafetyRouter.run_shield: {shield_id}")
@@ -6,14 +6,12 @@

from typing import Any

from llama_stack.apis.common.content_types import (
from llama_stack.log import get_logger
from llama_stack_api import (
    URL,
)
from llama_stack.apis.tools import (
    ListToolDefsResponse,
    ToolRuntime,
)
from llama_stack.log import get_logger

from ..routing_tables.toolgroups import ToolGroupsRoutingTable

@@ -36,16 +34,16 @@ class ToolRuntimeRouter(ToolRuntime):
        logger.debug("ToolRuntimeRouter.shutdown")
        pass

    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
    async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None) -> Any:
        logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
        provider = await self.routing_table.get_provider_impl(tool_name)
        return await provider.invoke_tool(
            tool_name=tool_name,
            kwargs=kwargs,
            authorization=authorization,
        )

    async def list_runtime_tools(
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
        self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None, authorization: str | None = None
    ) -> ListToolDefsResponse:
        logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
        return await self.routing_table.list_tools(tool_group_id)
        return await self.routing_table.list_tools(tool_group_id, authorization=authorization)
@@ -10,13 +10,20 @@ from typing import Annotated, Any

from fastapi import Body

from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack_api import (
    Chunk,
    HealthResponse,
    HealthStatus,
    InterleavedContent,
    ModelNotFoundError,
    ModelType,
    ModelTypeError,
    OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
    OpenAICreateVectorStoreRequestWithExtraBody,
    QueryChunksResponse,
    RoutingTable,
    SearchRankingOptions,
    VectorIO,
    VectorStoreChunkingStrategy,

@@ -24,7 +31,7 @@ from llama_stack.apis.vector_io import (
    VectorStoreChunkingStrategyStaticConfig,
    VectorStoreDeleteResponse,
    VectorStoreFileBatchObject,
    VectorStoreFileContentsResponse,
    VectorStoreFileContentResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFilesListInBatchResponse,

@@ -33,9 +40,6 @@ from llama_stack.apis.vector_io import (
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import VectorStoresConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

logger = get_logger(name=__name__, category="core::routers")

@@ -122,6 +126,14 @@ class VectorIORouter(VectorIO):
        if embedding_model is not None and embedding_dimension is None:
            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)

        # Validate that embedding model exists and is of the correct type
        if embedding_model is not None:
            model = await self.routing_table.get_object_by_identifier("model", embedding_model)
            if model is None:
                raise ModelNotFoundError(embedding_model)
            if model.model_type != ModelType.embedding:
                raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)

        # Auto-select provider if not specified
        if provider_id is None:
            num_providers = len(self.routing_table.impls_by_provider_id)

@@ -247,6 +259,13 @@ class VectorIORouter(VectorIO):
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")

        # Check if provider_id is being changed (not supported)
        if metadata and "provider_id" in metadata:
            current_store = await self.routing_table.get_object_by_identifier("vector_store", vector_store_id)
            if current_store and current_store.provider_id != metadata["provider_id"]:
                raise ValueError("provider_id cannot be changed after vector store creation")

        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,

@@ -338,12 +357,19 @@ class VectorIORouter(VectorIO):
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
        provider = await self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
        include_embeddings: bool | None = False,
        include_metadata: bool | None = False,
    ) -> VectorStoreFileContentResponse:
        logger.debug(
            f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}, "
            f"include_embeddings={include_embeddings}, include_metadata={include_metadata}"
        )

        return await self.routing_table.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
            include_embeddings=include_embeddings,
            include_metadata=include_metadata,
        )

    async def openai_update_vector_store_file(
@@ -6,11 +6,11 @@

from typing import Any

from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.core.datatypes import (
    BenchmarkWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import Benchmark, Benchmarks, ListBenchmarksResponse

from .common import CommonRoutingTableImpl

@@ -6,9 +6,6 @@

from typing import Any

from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import Action
from llama_stack.core.datatypes import (

@@ -21,7 +18,7 @@ from llama_stack.core.datatypes import (
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack_api import Api, Model, ModelNotFoundError, ResourceType, RoutingTable

logger = get_logger(name=__name__, category="core::routing_tables")
@@ -7,22 +7,22 @@
import uuid
from typing import Any

from llama_stack.apis.common.errors import DatasetNotFoundError
from llama_stack.apis.datasets import (
from llama_stack.core.datatypes import (
    DatasetWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import (
    Dataset,
    DatasetNotFoundError,
    DatasetPurpose,
    Datasets,
    DatasetType,
    DataSource,
    ListDatasetsResponse,
    ResourceType,
    RowsDataSource,
    URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
    DatasetWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

@@ -7,8 +7,6 @@
import time
from typing import Any

from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.core.datatypes import (
    ModelWithOwner,
    RegistryEntrySource,

@@ -16,6 +14,15 @@ from llama_stack.core.datatypes import (
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack_api import (
    ListModelsResponse,
    Model,
    ModelNotFoundError,
    Models,
    ModelType,
    OpenAIListModelsResponse,
    OpenAIModel,
)

from .common import CommonRoutingTableImpl, lookup_model
@@ -4,18 +4,18 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.core.datatypes import (
    ScoringFnWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import (
    ListScoringFunctionsResponse,
    ParamType,
    ResourceType,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)

from .common import CommonRoutingTableImpl

@@ -6,12 +6,11 @@

from typing import Any

from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.core.datatypes import (
    ShieldWithOwner,
)
from llama_stack.log import get_logger
from llama_stack_api import ListShieldsResponse, ResourceType, Shield, Shields

from .common import CommonRoutingTableImpl
@@ -6,11 +6,17 @@

from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.errors import ToolGroupNotFoundError
from llama_stack.apis.tools import ListToolDefsResponse, ListToolGroupsResponse, ToolDef, ToolGroup, ToolGroups
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
from llama_stack.log import get_logger
from llama_stack_api import (
    URL,
    ListToolDefsResponse,
    ListToolGroupsResponse,
    ToolDef,
    ToolGroup,
    ToolGroupNotFoundError,
    ToolGroups,
)

from .common import CommonRoutingTableImpl

@@ -43,7 +49,9 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
            routing_key = self.tool_to_toolgroup[routing_key]
        return await super().get_provider_impl(routing_key, provider_id)

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolDefsResponse:
    async def list_tools(
        self, toolgroup_id: str | None = None, authorization: str | None = None
    ) -> ListToolDefsResponse:
        if toolgroup_id:
            if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
                toolgroup_id = group_id

@@ -55,7 +63,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
        for toolgroup in toolgroups:
            if toolgroup.identifier not in self.toolgroups_to_tools:
                try:
                    await self._index_tools(toolgroup)
                    await self._index_tools(toolgroup, authorization=authorization)
                except AuthenticationRequiredError:
                    # Send authentication errors back to the client so it knows
                    # that it needs to supply credentials for remote MCP servers.

@@ -70,9 +78,11 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
        return ListToolDefsResponse(data=all_tools)

    async def _index_tools(self, toolgroup: ToolGroup):
    async def _index_tools(self, toolgroup: ToolGroup, authorization: str | None = None):
        provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
        tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
        tooldefs_response = await provider_impl.list_runtime_tools(
            toolgroup.identifier, toolgroup.mcp_endpoint, authorization=authorization
        )

        tooldefs = tooldefs_response.data
        for t in tooldefs:
@@ -6,26 +6,27 @@

from typing import Any

from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.core.datatypes import (
    VectorStoreWithOwner,
)
from llama_stack.log import get_logger

# Removed VectorStores import to avoid exposing public API
from llama_stack.apis.vector_io.vector_io import (
from llama_stack_api import (
    ModelNotFoundError,
    ModelType,
    ModelTypeError,
    ResourceType,
    SearchRankingOptions,
    VectorStoreChunkingStrategy,
    VectorStoreDeleteResponse,
    VectorStoreFileContentsResponse,
    VectorStoreFileContentResponse,
    VectorStoreFileDeleteResponse,
    VectorStoreFileObject,
    VectorStoreFileStatus,
    VectorStoreObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.core.datatypes import (
    VectorStoreWithOwner,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl, lookup_model

@@ -195,12 +196,17 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        include_embeddings: bool | None = False,
        include_metadata: bool | None = False,
    ) -> VectorStoreFileContentResponse:
        await self.assert_action_allowed("read", "vector_store", vector_store_id)

        provider = await self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
            include_embeddings=include_embeddings,
            include_metadata=include_metadata,
        )

    async def openai_update_vector_store_file(
@@ -13,7 +13,6 @@ import httpx
import jwt
from pydantic import BaseModel, Field

from llama_stack.apis.common.errors import TokenValidationError
from llama_stack.core.datatypes import (
    AuthenticationConfig,
    CustomAuthConfig,

@@ -23,6 +22,7 @@ from llama_stack.core.datatypes import (
    User,
)
from llama_stack.log import get_logger
from llama_stack_api import TokenValidationError

logger = get_logger(name=__name__, category="core::auth")

@@ -11,9 +11,9 @@ from datetime import UTC, datetime, timedelta
from starlette.types import ASGIApp, Receive, Scope, Send

from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendType
from llama_stack.core.storage.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.kvstore import _KVSTORE_BACKENDS, kvstore_impl
from llama_stack_api.internal.kvstore import KVStore

logger = get_logger(name=__name__, category="core::server")
@@ -12,9 +12,8 @@ from typing import Any
from aiohttp import hdrs
from starlette.routing import Route

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.core.resolver import api_protocol_map
from llama_stack.schema_utils import WebMethod
from llama_stack_api import Api, ExternalApiSpec, WebMethod

EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]

@@ -31,8 +31,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError

from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.core.access_control.access_control import AccessDeniedError
from llama_stack.core.datatypes import (
    AuthenticationRequiredError,

@@ -58,7 +56,7 @@ from llama_stack.core.utils.config import redact_sensitive_fields
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.context import preserve_contexts_async_generator
from llama_stack.log import LoggingConfig, get_logger, setup_logging
from llama_stack.providers.datatypes import Api
from llama_stack_api import Api, ConflictError, PaginatedResponse, ResourceNotFoundError

from .auth import AuthenticationMiddleware
from .quota import QuotaMiddleware

@@ -526,8 +524,8 @@ def extract_path_params(route: str) -> list[str]:

def remove_disabled_providers(obj):
    if isinstance(obj, dict):
        keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
        if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
        # Filter out items where provider_id is explicitly disabled or empty
        if "provider_id" in obj and obj["provider_id"] in ("__disabled__", "", None):
            return None
        return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
    elif isinstance(obj, list):
@@ -13,26 +13,6 @@ from typing import Any

import yaml

from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.prompts import Prompts
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
from llama_stack.core.datatypes import Provider, SafetyConfig, StackRunConfig, VectorStoresConfig
from llama_stack.core.distribution import get_provider_registry

@@ -54,7 +34,30 @@ from llama_stack.core.storage.datatypes import (

from llama_stack.core.store.registry import create_dist_registry
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
from llama_stack_api import (
    Agents,
    Api,
    Batches,
    Benchmarks,
    Conversations,
    DatasetIO,
    Datasets,
    Eval,
    Files,
    Inference,
    Inspect,
    Models,
    PostTraining,
    Prompts,
    Providers,
    Safety,
    Scoring,
    ScoringFunctions,
    Shields,
    ToolGroups,
    ToolRuntime,
    VectorIO,
)

logger = get_logger(name=__name__, category="core")
@@ -382,8 +385,8 @@ def _initialize_storage(run_config: StackRunConfig):
    else:
        raise ValueError(f"Unknown storage backend type: {type}")

    from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
    from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
    from llama_stack.core.storage.kvstore.kvstore import register_kvstore_backends
    from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends

    register_kvstore_backends(kv_backends)
    register_sqlstore_backends(sql_backends)
@@ -12,6 +12,8 @@ from typing import Annotated, Literal

from pydantic import BaseModel, Field, field_validator

from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR


class StorageBackendType(StrEnum):
    KV_REDIS = "kv_redis"
@@ -256,15 +258,24 @@ class ResponsesStoreReference(InferenceStoreReference):

class ServerStoresConfig(BaseModel):
    metadata: KVStoreReference | None = Field(
        default=None,
        default=KVStoreReference(
            backend="kv_default",
            namespace="registry",
        ),
        description="Metadata store configuration (uses KV backend)",
    )
    inference: InferenceStoreReference | None = Field(
        default=None,
        default=InferenceStoreReference(
            backend="sql_default",
            table_name="inference_store",
        ),
        description="Inference store configuration (uses SQL backend)",
    )
    conversations: SqlStoreReference | None = Field(
        default=None,
        default=SqlStoreReference(
            backend="sql_default",
            table_name="openai_conversations",
        ),
        description="Conversations store configuration (uses SQL backend)",
    )
    responses: ResponsesStoreReference | None = Field(
@@ -272,13 +283,21 @@ class ServerStoresConfig(BaseModel):
        description="Responses store configuration (uses SQL backend)",
    )
    prompts: KVStoreReference | None = Field(
        default=None,
        default=KVStoreReference(backend="kv_default", namespace="prompts"),
        description="Prompts store configuration (uses KV backend)",
    )


class StorageConfig(BaseModel):
    backends: dict[str, StorageBackendConfig] = Field(
        default={
            "kv_default": SqliteKVStoreConfig(
                db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/kvstore.db",
            ),
            "sql_default": SqliteSqlStoreConfig(
                db_path=f"${{env.SQLITE_STORE_DIR:={DISTRIBS_BASE_DIR}}}/sql_store.db",
            ),
        },
        description="Named backend configurations (e.g., 'default', 'cache')",
    )
    stores: ServerStoresConfig = Field(
@@ -4,4 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack_api.internal.kvstore import KVStore as KVStore

from .kvstore import *  # noqa: F401, F403
@@ -11,10 +11,21 @@

from __future__ import annotations

from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig, StorageBackendType
import asyncio
from collections import defaultdict
from datetime import datetime
from typing import cast

from .api import KVStore
from .config import KVStoreConfig
from llama_stack.core.storage.datatypes import KVStoreReference, StorageBackendConfig
from llama_stack_api.internal.kvstore import KVStore

from .config import (
    KVStoreConfig,
    MongoDBKVStoreConfig,
    PostgresKVStoreConfig,
    RedisKVStoreConfig,
    SqliteKVStoreConfig,
)


def kvstore_dependencies():
@@ -30,7 +41,7 @@ def kvstore_dependencies():

class InmemoryKVStoreImpl(KVStore):
    def __init__(self):
        self._store = {}
        self._store: dict[str, str] = {}

    async def initialize(self) -> None:
        pass

@@ -38,7 +49,7 @@ class InmemoryKVStoreImpl(KVStore):
    async def get(self, key: str) -> str | None:
        return self._store.get(key)

    async def set(self, key: str, value: str) -> None:
    async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
        self._store[key] = value

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
@@ -53,45 +64,65 @@ class InmemoryKVStoreImpl(KVStore):


_KVSTORE_BACKENDS: dict[str, KVStoreConfig] = {}
_KVSTORE_INSTANCES: dict[tuple[str, str], KVStore] = {}
_KVSTORE_LOCKS: defaultdict[tuple[str, str], asyncio.Lock] = defaultdict(asyncio.Lock)


def register_kvstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
    """Register the set of available KV store backends for reference resolution."""
    global _KVSTORE_BACKENDS
    global _KVSTORE_INSTANCES
    global _KVSTORE_LOCKS

    _KVSTORE_BACKENDS.clear()
    _KVSTORE_INSTANCES.clear()
    _KVSTORE_LOCKS.clear()
    for name, cfg in backends.items():
        _KVSTORE_BACKENDS[name] = cfg
        typed_cfg = cast(KVStoreConfig, cfg)
        _KVSTORE_BACKENDS[name] = typed_cfg


async def kvstore_impl(reference: KVStoreReference) -> KVStore:
    backend_name = reference.backend
    cache_key = (backend_name, reference.namespace)

    existing = _KVSTORE_INSTANCES.get(cache_key)
    if existing:
        return existing

    backend_config = _KVSTORE_BACKENDS.get(backend_name)
    if backend_config is None:
        raise ValueError(f"Unknown KVStore backend '{backend_name}'. Registered backends: {sorted(_KVSTORE_BACKENDS)}")

    config = backend_config.model_copy()
    config.namespace = reference.namespace
    lock = _KVSTORE_LOCKS[cache_key]
    async with lock:
        existing = _KVSTORE_INSTANCES.get(cache_key)
        if existing:
            return existing

    if config.type == StorageBackendType.KV_REDIS.value:
        from .redis import RedisKVStoreImpl
        config = backend_config.model_copy()
        config.namespace = reference.namespace

        impl = RedisKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_SQLITE.value:
        from .sqlite import SqliteKVStoreImpl
        impl: KVStore
        if isinstance(config, RedisKVStoreConfig):
            from .redis import RedisKVStoreImpl

        impl = SqliteKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_POSTGRES.value:
        from .postgres import PostgresKVStoreImpl
            impl = RedisKVStoreImpl(config)
        elif isinstance(config, SqliteKVStoreConfig):
            from .sqlite import SqliteKVStoreImpl

        impl = PostgresKVStoreImpl(config)
    elif config.type == StorageBackendType.KV_MONGODB.value:
        from .mongodb import MongoDBKVStoreImpl
            impl = SqliteKVStoreImpl(config)
        elif isinstance(config, PostgresKVStoreConfig):
            from .postgres import PostgresKVStoreImpl

        impl = MongoDBKVStoreImpl(config)
    else:
        raise ValueError(f"Unknown kvstore type {config.type}")
            impl = PostgresKVStoreImpl(config)
        elif isinstance(config, MongoDBKVStoreConfig):
            from .mongodb import MongoDBKVStoreImpl

    await impl.initialize()
    return impl
            impl = MongoDBKVStoreImpl(config)
        else:
            raise ValueError(f"Unknown kvstore type {config.type}")

        await impl.initialize()
        _KVSTORE_INSTANCES[cache_key] = impl
        return impl
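The hunk above shows the old and new bodies of `kvstore_impl` interleaved; the new version resolves a `KVStoreReference` against the registered backends and caches one instance per `(backend, namespace)` pair behind an `asyncio.Lock`. A minimal usage sketch of that flow, assuming the default `kv_default` SQLite backend named elsewhere in this diff (the db path and namespace here are illustrative):

```python
# Sketch only: how the registration/resolution pair above is intended to be used.
import asyncio

from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.core.storage.kvstore.kvstore import kvstore_impl, register_kvstore_backends


async def main() -> None:
    # Normally _initialize_storage() does this from the run config.
    register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/kvstore.db")})

    ref = KVStoreReference(backend="kv_default", namespace="registry")
    store = await kvstore_impl(ref)
    # A second resolution of the same (backend, namespace) returns the cached instance.
    assert store is await kvstore_impl(ref)

    await store.set("hello", "world")
    print(await store.get("hello"))


asyncio.run(main())
```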
@@ -9,8 +9,8 @@ from datetime import datetime

from pymongo import AsyncMongoClient
from pymongo.asynchronous.collection import AsyncCollection

from llama_stack.core.storage.kvstore import KVStore
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore

from ..config import MongoDBKVStoreConfig
@@ -6,12 +6,13 @@

from datetime import datetime

import psycopg2
from psycopg2.extras import DictCursor
import psycopg2  # type: ignore[import-not-found]
from psycopg2.extensions import connection as PGConnection  # type: ignore[import-not-found]
from psycopg2.extras import DictCursor  # type: ignore[import-not-found]

from llama_stack.log import get_logger
from llama_stack_api.internal.kvstore import KVStore

from ..api import KVStore
from ..config import PostgresKVStoreConfig

log = get_logger(name=__name__, category="providers::utils")
@ -20,12 +21,12 @@ log = get_logger(name=__name__, category="providers::utils")
|
|||
class PostgresKVStoreImpl(KVStore):
|
||||
def __init__(self, config: PostgresKVStoreConfig):
|
||||
self.config = config
|
||||
self.conn = None
|
||||
self.cursor = None
|
||||
self._conn: PGConnection | None = None
|
||||
self._cursor: DictCursor | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
self.conn = psycopg2.connect(
|
||||
self._conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
port=self.config.port,
|
||||
database=self.config.db,
|
||||
|
|
@ -34,11 +35,11 @@ class PostgresKVStoreImpl(KVStore):
|
|||
sslmode=self.config.ssl_mode,
|
||||
sslrootcert=self.config.ca_cert_path,
|
||||
)
|
||||
self.conn.autocommit = True
|
||||
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
|
||||
self._conn.autocommit = True
|
||||
self._cursor = self._conn.cursor(cursor_factory=DictCursor)
|
||||
|
||||
# Create table if it doesn't exist
|
||||
self.cursor.execute(
|
||||
self._cursor.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.config.table_name} (
|
||||
key TEXT PRIMARY KEY,
|
||||
|
|
@ -51,6 +52,11 @@ class PostgresKVStoreImpl(KVStore):
|
|||
log.exception("Could not connect to PostgreSQL database server")
|
||||
raise RuntimeError("Could not connect to PostgreSQL database server") from e
|
||||
|
||||
def _cursor_or_raise(self) -> DictCursor:
|
||||
if self._cursor is None:
|
||||
raise RuntimeError("Postgres client not initialized")
|
||||
return self._cursor
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
return key
|
||||
|
|
@ -58,7 +64,8 @@ class PostgresKVStoreImpl(KVStore):
|
|||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
self.cursor.execute(
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"""
|
||||
INSERT INTO {self.config.table_name} (key, value, expiration)
|
||||
VALUES (%s, %s, %s)
|
||||
|
|
@ -70,7 +77,8 @@ class PostgresKVStoreImpl(KVStore):
|
|||
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
self.cursor.execute(
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT value FROM {self.config.table_name}
|
||||
WHERE key = %s
|
||||
|
|
@ -78,12 +86,13 @@ class PostgresKVStoreImpl(KVStore):
|
|||
""",
|
||||
(key,),
|
||||
)
|
||||
result = self.cursor.fetchone()
|
||||
result = cursor.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
self.cursor.execute(
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"DELETE FROM {self.config.table_name} WHERE key = %s",
|
||||
(key,),
|
||||
)
|
||||
|
|
@ -92,7 +101,8 @@ class PostgresKVStoreImpl(KVStore):
|
|||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
self.cursor.execute(
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"""
|
||||
SELECT value FROM {self.config.table_name}
|
||||
WHERE key >= %s AND key < %s
|
||||
|
|
@ -101,14 +111,15 @@ class PostgresKVStoreImpl(KVStore):
|
|||
""",
|
||||
(start_key, end_key),
|
||||
)
|
||||
return [row[0] for row in self.cursor.fetchall()]
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
self.cursor.execute(
|
||||
cursor = self._cursor_or_raise()
|
||||
cursor.execute(
|
||||
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
|
||||
(start_key, end_key),
|
||||
)
|
||||
return [row[0] for row in self.cursor.fetchall()]
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
|
|
@ -6,18 +6,25 @@
|
|||
|
||||
from datetime import datetime
|
||||
|
||||
from redis.asyncio import Redis
|
||||
from redis.asyncio import Redis # type: ignore[import-not-found]
|
||||
|
||||
from llama_stack_api.internal.kvstore import KVStore
|
||||
|
||||
from ..api import KVStore
|
||||
from ..config import RedisKVStoreConfig
|
||||
|
||||
|
||||
class RedisKVStoreImpl(KVStore):
|
||||
def __init__(self, config: RedisKVStoreConfig):
|
||||
self.config = config
|
||||
self._redis: Redis | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.redis = Redis.from_url(self.config.url)
|
||||
self._redis = Redis.from_url(self.config.url)
|
||||
|
||||
def _client(self) -> Redis:
|
||||
if self._redis is None:
|
||||
raise RuntimeError("Redis client not initialized")
|
||||
return self._redis
|
||||
|
||||
def _namespaced_key(self, key: str) -> str:
|
||||
if not self.config.namespace:
|
||||
|
|
@ -26,30 +33,37 @@ class RedisKVStoreImpl(KVStore):
|
|||
|
||||
async def set(self, key: str, value: str, expiration: datetime | None = None) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.set(key, value)
|
||||
client = self._client()
|
||||
await client.set(key, value)
|
||||
if expiration:
|
||||
await self.redis.expireat(key, expiration)
|
||||
await client.expireat(key, expiration)
|
||||
|
||||
async def get(self, key: str) -> str | None:
|
||||
key = self._namespaced_key(key)
|
||||
value = await self.redis.get(key)
|
||||
client = self._client()
|
||||
value = await client.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
await self.redis.ttl(key)
|
||||
return value
|
||||
await client.ttl(key)
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.delete(key)
|
||||
await self._client().delete(key)
|
||||
|
||||
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
client = self._client()
|
||||
cursor = 0
|
||||
pattern = start_key + "*" # Match all keys starting with start_key prefix
|
||||
matching_keys = []
|
||||
matching_keys: list[str | bytes] = []
|
||||
while True:
|
||||
cursor, keys = await self.redis.scan(cursor, match=pattern, count=1000)
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
|
||||
|
||||
for key in keys:
|
||||
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
|
||||
|
|
@ -61,7 +75,7 @@ class RedisKVStoreImpl(KVStore):
|
|||
|
||||
# Then fetch all values in a single MGET call
|
||||
if matching_keys:
|
||||
values = await self.redis.mget(matching_keys)
|
||||
values = await client.mget(matching_keys)
|
||||
return [
|
||||
value.decode("utf-8") if isinstance(value, bytes) else value for value in values if value is not None
|
||||
]
|
||||
|
|
@ -70,7 +84,18 @@ class RedisKVStoreImpl(KVStore):
|
|||
|
||||
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||
"""Get all keys in the given range."""
|
||||
matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
|
||||
if not matching_keys:
|
||||
return []
|
||||
return [k.decode("utf-8") for k in matching_keys]
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
client = self._client()
|
||||
cursor = 0
|
||||
pattern = start_key + "*"
|
||||
result: list[str] = []
|
||||
while True:
|
||||
cursor, keys = await client.scan(cursor, match=pattern, count=1000)
|
||||
for key in keys:
|
||||
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key)
|
||||
if start_key <= key_str <= end_key:
|
||||
result.append(key_str)
|
||||
if cursor == 0:
|
||||
break
|
||||
return result
|
||||
|
|
@@ -10,8 +10,8 @@ from datetime import datetime

import aiosqlite

from llama_stack.log import get_logger
from llama_stack_api.internal.kvstore import KVStore

from ..api import KVStore
from ..config import SqliteKVStoreConfig

logger = get_logger(name=__name__, category="providers::utils")

src/llama_stack/core/storage/sqlstore/__init__.py (new file, 17 lines)

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack_api.internal.sqlstore import (
    ColumnDefinition as ColumnDefinition,
)
from llama_stack_api.internal.sqlstore import (
    ColumnType as ColumnType,
)
from llama_stack_api.internal.sqlstore import (
    SqlStore as SqlStore,
)

from .sqlstore import *  # noqa: F401,F403
@@ -14,8 +14,8 @@ from llama_stack.core.datatypes import User

from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.core.storage.datatypes import StorageBackendType
from llama_stack.log import get_logger

from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from llama_stack_api import PaginatedResponse
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore

logger = get_logger(name=__name__, category="providers::utils")

@@ -45,8 +45,13 @@ def _enhance_item_with_access_control(item: Mapping[str, Any], current_user: Use
        enhanced["owner_principal"] = current_user.principal
        enhanced["access_attributes"] = current_user.attributes
    else:
        enhanced["owner_principal"] = None
        enhanced["access_attributes"] = None
        # IMPORTANT: Use empty string and null value (not None) to match public access filter
        # The public access filter in _get_public_access_conditions() expects:
        #   - owner_principal = '' (empty string)
        #   - access_attributes = null (JSON null, which serializes to the string 'null')
        # Setting them to None (SQL NULL) will cause rows to be filtered out on read.
        enhanced["owner_principal"] = ""
        enhanced["access_attributes"] = None  # Pydantic/JSON will serialize this as JSON null
    return enhanced
|
|
@@ -124,6 +129,23 @@ class AuthorizedSqlStore:
        enhanced_data = [_enhance_item_with_access_control(item, current_user) for item in data]
        await self.sql_store.insert(table, enhanced_data)

    async def upsert(
        self,
        table: str,
        data: Mapping[str, Any],
        conflict_columns: list[str],
        update_columns: list[str] | None = None,
    ) -> None:
        """Upsert a row with automatic access control attribute capture."""
        current_user = get_authenticated_user()
        enhanced_data = _enhance_item_with_access_control(data, current_user)
        await self.sql_store.upsert(
            table=table,
            data=enhanced_data,
            conflict_columns=conflict_columns,
            update_columns=update_columns,
        )

    async def fetch_all(
        self,
        table: str,
|
|
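The `upsert` added above mirrors `insert`, but routes through the underlying store's conflict handling. A rough call-site sketch, assuming an already-initialized `AuthorizedSqlStore` named `store` and an illustrative `responses` table keyed by `id` (neither name is taken from the diff):

```python
# Sketch only: illustrative table and column names.
await store.upsert(
    table="responses",
    data={"id": "resp_123", "payload": "{}"},
    conflict_columns=["id"],        # identity columns for conflict detection
    update_columns=["payload"],     # refreshed when the row already exists
)
# owner_principal / access_attributes are filled in from the authenticated
# user by _enhance_item_with_access_control(), exactly as insert() does.
```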
@@ -188,8 +210,9 @@ class AuthorizedSqlStore:
            enhanced_data["owner_principal"] = current_user.principal
            enhanced_data["access_attributes"] = current_user.attributes
        else:
            enhanced_data["owner_principal"] = None
            enhanced_data["access_attributes"] = None
            # IMPORTANT: Use empty string for owner_principal to match public access filter
            enhanced_data["owner_principal"] = ""
            enhanced_data["access_attributes"] = None  # Will serialize as JSON null

        await self.sql_store.update(table, enhanced_data, where)
|
||||
|
|
@@ -245,14 +268,24 @@ class AuthorizedSqlStore:
            raise ValueError(f"Unsupported database type: {self.database_type}")

    def _get_public_access_conditions(self) -> list[str]:
        """Get the SQL conditions for public access."""
        # Public records are records that have no owner_principal or access_attributes
        """Get the SQL conditions for public access.

        Public records are those with:
        - owner_principal = '' (empty string)
        - access_attributes is either SQL NULL or JSON null

        Note: Different databases serialize None differently:
        - SQLite: None → JSON null (text = 'null')
        - Postgres: None → SQL NULL (IS NULL)
        """
        conditions = ["owner_principal = ''"]
        if self.database_type == StorageBackendType.SQL_POSTGRES.value:
            # Postgres stores JSON null as 'null'
            conditions.append("access_attributes::text = 'null'")
            # Accept both SQL NULL and JSON null for Postgres compatibility
            # This handles both old rows (SQL NULL) and new rows (JSON null)
            conditions.append("(access_attributes IS NULL OR access_attributes::text = 'null')")
        elif self.database_type == StorageBackendType.SQL_SQLITE.value:
            conditions.append("access_attributes = 'null'")
            # SQLite serializes None as JSON null
            conditions.append("(access_attributes IS NULL OR access_attributes = 'null')")
        else:
            raise ValueError(f"Unsupported database type: {self.database_type}")
        return conditions
|
|
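Taken together, the rewritten filter above treats a row as public only when `owner_principal` is the empty string and `access_attributes` is either SQL NULL or JSON null. A small sketch of the resulting filter fragment for the SQLite branch (how it is combined with the ownership conditions is assumed, not shown in this hunk):

```python
# Sketch only: the strings the new SQLite branch produces.
conditions = [
    "owner_principal = ''",
    "(access_attributes IS NULL OR access_attributes = 'null')",
]
public_filter = " AND ".join(conditions)
# -> "owner_principal = '' AND (access_attributes IS NULL OR access_attributes = 'null')"
```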
@ -26,11 +26,10 @@ from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
|||
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.core.storage.datatypes import SqlAlchemySqlStoreConfig
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||
from llama_stack_api import PaginatedResponse
|
||||
from llama_stack_api.internal.sqlstore import ColumnDefinition, ColumnType, SqlStore
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
|
@ -72,13 +71,14 @@ def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
|||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
self._is_sqlite_backend = "sqlite" in self.config.engine_str
|
||||
self.async_session = async_sessionmaker(self.create_engine())
|
||||
self.metadata = MetaData()
|
||||
|
||||
def create_engine(self) -> AsyncEngine:
|
||||
# Configure connection args for better concurrency support
|
||||
connect_args = {}
|
||||
if "sqlite" in self.config.engine_str:
|
||||
if self._is_sqlite_backend:
|
||||
# SQLite-specific optimizations for concurrent access
|
||||
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
|
||||
connect_args["timeout"] = 5.0
|
||||
|
|
@ -91,7 +91,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
)
|
||||
|
||||
# Enable WAL mode for SQLite to support concurrent readers and writers
|
||||
if "sqlite" in self.config.engine_str:
|
||||
if self._is_sqlite_backend:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
|
|
@ -151,6 +151,29 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
await session.execute(self.metadata.tables[table].insert(), data)
|
||||
await session.commit()
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
table: str,
|
||||
data: Mapping[str, Any],
|
||||
conflict_columns: list[str],
|
||||
update_columns: list[str] | None = None,
|
||||
) -> None:
|
||||
table_obj = self.metadata.tables[table]
|
||||
dialect_insert = self._get_dialect_insert(table_obj)
|
||||
insert_stmt = dialect_insert.values(**data)
|
||||
|
||||
if update_columns is None:
|
||||
update_columns = [col for col in data.keys() if col not in conflict_columns]
|
||||
|
||||
update_mapping = {col: getattr(insert_stmt.excluded, col) for col in update_columns}
|
||||
conflict_cols = [table_obj.c[col] for col in conflict_columns]
|
||||
|
||||
stmt = insert_stmt.on_conflict_do_update(index_elements=conflict_cols, set_=update_mapping)
|
||||
|
||||
async with self.async_session() as session:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
table: str,
|
||||
|
|
@ -333,9 +356,18 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
|
||||
|
||||
await conn.execute(add_column_sql)
|
||||
|
||||
except Exception as e:
|
||||
# If any error occurs during migration, log it but don't fail
|
||||
# The table creation will handle adding the column
|
||||
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||
pass
|
||||
|
||||
def _get_dialect_insert(self, table: Table):
|
||||
if self._is_sqlite_backend:
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
return sqlite_insert(table)
|
||||
else:
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
return pg_insert(table)
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from threading import Lock
|
||||
from typing import Annotated, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
|
@ -15,12 +16,13 @@ from llama_stack.core.storage.datatypes import (
|
|||
StorageBackendConfig,
|
||||
StorageBackendType,
|
||||
)
|
||||
|
||||
from .api import SqlStore
|
||||
from llama_stack_api.internal.sqlstore import SqlStore
|
||||
|
||||
sql_store_pip_packages = ["sqlalchemy[asyncio]", "aiosqlite", "asyncpg"]
|
||||
|
||||
_SQLSTORE_BACKENDS: dict[str, StorageBackendConfig] = {}
|
||||
_SQLSTORE_INSTANCES: dict[str, SqlStore] = {}
|
||||
_SQLSTORE_LOCKS: dict[str, Lock] = {}
|
||||
|
||||
|
||||
SqlStoreConfig = Annotated[
|
||||
|
|
@ -52,19 +54,34 @@ def sqlstore_impl(reference: SqlStoreReference) -> SqlStore:
|
|||
f"Unknown SQL store backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
|
||||
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
existing = _SQLSTORE_INSTANCES.get(backend_name)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
|
||||
return SqlAlchemySqlStoreImpl(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
|
||||
lock = _SQLSTORE_LOCKS.setdefault(backend_name, Lock())
|
||||
with lock:
|
||||
existing = _SQLSTORE_INSTANCES.get(backend_name)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
if isinstance(backend_config, SqliteSqlStoreConfig | PostgresSqlStoreConfig):
|
||||
from .sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl
|
||||
|
||||
config = cast(SqliteSqlStoreConfig | PostgresSqlStoreConfig, backend_config).model_copy()
|
||||
instance = SqlAlchemySqlStoreImpl(config)
|
||||
_SQLSTORE_INSTANCES[backend_name] = instance
|
||||
return instance
|
||||
else:
|
||||
raise ValueError(f"Unknown sqlstore type {backend_config.type}")
|
||||
|
||||
|
||||
def register_sqlstore_backends(backends: dict[str, StorageBackendConfig]) -> None:
|
||||
"""Register the set of available SQL store backends for reference resolution."""
|
||||
global _SQLSTORE_BACKENDS
|
||||
global _SQLSTORE_INSTANCES
|
||||
|
||||
_SQLSTORE_BACKENDS.clear()
|
||||
_SQLSTORE_INSTANCES.clear()
|
||||
_SQLSTORE_LOCKS.clear()
|
||||
for name, cfg in backends.items():
|
||||
_SQLSTORE_BACKENDS[name] = cfg
|
||||
|
|
@ -12,8 +12,8 @@ import pydantic
|
|||
|
||||
from llama_stack.core.datatypes import RoutableObjectWithProvider
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference
|
||||
from llama_stack.core.storage.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
|
||||
logger = get_logger(__name__, category="core::registry")
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Primitive
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
from llama_stack_api import json_schema_type, register_schema
|
||||
|
||||
ROOT_SPAN_MARKERS = ["__root__", "__root_span__"]
|
||||
|
||||
|
|
|
|||
src/llama_stack/core/utils/type_inspection.py (new file, 45 lines)
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Utility functions for type inspection and parameter handling.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
def is_unwrapped_body_param(param_type: Any) -> bool:
|
||||
"""
|
||||
Check if a parameter type represents an unwrapped body parameter.
|
||||
An unwrapped body parameter is an Annotated type with Body(embed=False)
|
||||
|
||||
This is used to determine whether request parameters should be flattened
|
||||
in OpenAPI specs and client libraries (matching FastAPI's embed=False behavior).
|
||||
|
||||
Args:
|
||||
param_type: The parameter type annotation to check
|
||||
|
||||
Returns:
|
||||
True if the parameter should be treated as an unwrapped body parameter
|
||||
"""
|
||||
# Check if it's Annotated with Body(embed=False)
|
||||
if get_origin(param_type) is typing.Annotated:
|
||||
args = get_args(param_type)
|
||||
base_type = args[0]
|
||||
metadata = args[1:]
|
||||
|
||||
# Look for Body annotation with embed=False
|
||||
# Body() returns a FieldInfo object, so we check for that type and the embed attribute
|
||||
for item in metadata:
|
||||
if isinstance(item, FieldInfo) and hasattr(item, "embed") and not item.embed:
|
||||
return inspect.isclass(base_type) and issubclass(base_type, BaseModel)
|
||||
|
||||
return False
|
||||
|
|
@ -13,6 +13,5 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
|
|||
def get_distribution_template() -> DistributionTemplate:
|
||||
template = get_starter_distribution_template(name="ci-tests")
|
||||
template.description = "CI tests for Llama Stack"
|
||||
template.run_configs.pop("run-with-postgres-store.yaml", None)
|
||||
|
||||
return template
|
||||
|
|
|
|||
|
|
@ -0,0 +1,296 @@
|
|||
version: 2
|
||||
image_name: ci-tests
|
||||
apis:
|
||||
- agents
|
||||
- batches
|
||||
- datasetio
|
||||
- eval
|
||||
- file_processor
|
||||
- files
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
base_url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
base_url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
base_url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
base_url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
api_key: ${env.OPENAI_API_KEY:=}
|
||||
base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1}
|
||||
- provider_id: anthropic
|
||||
provider_type: remote::anthropic
|
||||
config:
|
||||
api_key: ${env.ANTHROPIC_API_KEY:=}
|
||||
- provider_id: gemini
|
||||
provider_type: remote::gemini
|
||||
config:
|
||||
api_key: ${env.GEMINI_API_KEY:=}
|
||||
- provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
|
||||
provider_type: remote::vertexai
|
||||
config:
|
||||
project: ${env.VERTEX_AI_PROJECT:=}
|
||||
location: ${env.VERTEX_AI_LOCATION:=us-central1}
|
||||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
base_url: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
base_url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
base_url: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
persistence:
|
||||
namespace: vector_io::faiss
|
||||
backend: kv_default
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
config:
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
|
||||
persistence:
|
||||
namespace: vector_io::sqlite_vec
|
||||
backend: kv_default
|
||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||
provider_type: inline::milvus
|
||||
config:
|
||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
|
||||
persistence:
|
||||
namespace: vector_io::milvus
|
||||
backend: kv_default
|
||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
url: ${env.CHROMADB_URL:=}
|
||||
persistence:
|
||||
namespace: vector_io::chroma_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||
provider_type: remote::pgvector
|
||||
config:
|
||||
host: ${env.PGVECTOR_HOST:=localhost}
|
||||
port: ${env.PGVECTOR_PORT:=5432}
|
||||
db: ${env.PGVECTOR_DB:=}
|
||||
user: ${env.PGVECTOR_USER:=}
|
||||
password: ${env.PGVECTOR_PASSWORD:=}
|
||||
persistence:
|
||||
namespace: vector_io::pgvector
|
||||
backend: kv_default
|
||||
- provider_id: ${env.QDRANT_URL:+qdrant}
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
api_key: ${env.QDRANT_API_KEY:=}
|
||||
persistence:
|
||||
namespace: vector_io::qdrant_remote
|
||||
backend: kv_default
|
||||
- provider_id: ${env.WEAVIATE_CLUSTER_URL:+weaviate}
|
||||
provider_type: remote::weaviate
|
||||
config:
|
||||
weaviate_api_key: null
|
||||
weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080}
|
||||
persistence:
|
||||
namespace: vector_io::weaviate
|
||||
backend: kv_default
|
||||
files:
|
||||
- provider_id: meta-reference-files
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
|
||||
metadata_store:
|
||||
table_name: files_metadata
|
||||
backend: sql_default
|
||||
file_processor:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
- provider_id: code-scanner
|
||||
provider_type: inline::code-scanner
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
post_training:
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
config:
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: eval
|
||||
backend: kv_default
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::huggingface
|
||||
backend: kv_default
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
namespace: datasetio::localfs
|
||||
backend: kv_default
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:=}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:=}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
batches:
|
||||
- provider_id: reference
|
||||
provider_type: inline::reference
|
||||
config:
|
||||
kvstore:
|
||||
namespace: batches
|
||||
backend: kv_default
|
||||
storage:
|
||||
backends:
|
||||
kv_default:
|
||||
type: kv_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
|
||||
sql_default:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
safety:
|
||||
default_shield_id: llama-guard
|
||||
|
|
@ -18,44 +18,43 @@ providers:
|
|||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
base_url: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
base_url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
base_url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
base_url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
base_url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
|
|
@ -77,18 +76,18 @@ providers:
|
|||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
base_url: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
base_url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
base_url: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
|
|
@ -17,6 +16,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
|
|||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig
|
||||
from llama_stack_api import ModelType
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
|
|
@ -22,6 +21,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import (
|
|||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack_api import ModelType
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
|
|
|
|||
|
|
@ -16,9 +16,8 @@ providers:
|
|||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
|
|
|
|||
|
|
@ -16,9 +16,8 @@ providers:
|
|||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .agents import *
|
||||
from .oci import get_distribution_template # noqa: F401
|
||||
src/llama_stack/distributions/oci/build.yaml (new file, 35 lines)
|
|
@ -0,0 +1,35 @@
|
|||
version: 2
|
||||
distribution_spec:
|
||||
description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM
|
||||
inference with scalable cloud services
|
||||
providers:
|
||||
inference:
|
||||
- provider_type: remote::oci
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
- provider_type: remote::chromadb
|
||||
- provider_type: remote::pgvector
|
||||
safety:
|
||||
- provider_type: inline::llama-guard
|
||||
agents:
|
||||
- provider_type: inline::meta-reference
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
- provider_type: remote::huggingface
|
||||
- provider_type: inline::localfs
|
||||
scoring:
|
||||
- provider_type: inline::basic
|
||||
- provider_type: inline::llm-as-judge
|
||||
- provider_type: inline::braintrust
|
||||
tool_runtime:
|
||||
- provider_type: remote::brave-search
|
||||
- provider_type: remote::tavily-search
|
||||
- provider_type: inline::rag-runtime
|
||||
- provider_type: remote::model-context-protocol
|
||||
files:
|
||||
- provider_type: inline::localfs
|
||||
image_type: venv
|
||||
additional_pip_packages:
|
||||
- aiosqlite
|
||||
- sqlalchemy[asyncio]
|
||||
src/llama_stack/distributions/oci/doc_template.md (new file, 140 lines)
|
|
@ -0,0 +1,140 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# OCI Distribution
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||
|
||||
{{ providers_table }}
|
||||
|
||||
{% if run_config_env_vars %}
|
||||
### Environment Variables
|
||||
|
||||
The following environment variables can be configured:
|
||||
|
||||
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if default_models %}
|
||||
### Models
|
||||
|
||||
The following models are available by default:
|
||||
|
||||
{% for model in default_models %}
|
||||
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## Prerequisites
|
||||
### Oracle Cloud Infrastructure Setup
|
||||
|
||||
Before using the OCI Generative AI distribution, ensure you have:
|
||||
|
||||
1. **Oracle Cloud Infrastructure Account**: Sign up at [Oracle Cloud Infrastructure](https://cloud.oracle.com/)
|
||||
2. **Generative AI Service Access**: Enable the Generative AI service in your OCI tenancy
|
||||
3. **Compartment**: Create or identify a compartment where you'll deploy Generative AI models
|
||||
4. **Authentication**: Configure authentication using either:
|
||||
- **Instance Principal** (recommended for cloud-hosted deployments)
|
||||
- **API Key** (for on-premises or development environments)
|
||||
|
||||
### Authentication Methods
|
||||
|
||||
#### Instance Principal Authentication (Recommended)
|
||||
Instance Principal authentication allows OCI resources to authenticate using the identity of the compute instance they're running on. This is the most secure method for production deployments.
|
||||
|
||||
Requirements:
|
||||
- Instance must be running in an Oracle Cloud Infrastructure compartment
|
||||
- Instance must have appropriate IAM policies to access Generative AI services
|
||||
|
||||
#### API Key Authentication
|
||||
For development or on-premises deployments, follow [this doc](https://docs.oracle.com/en-us/iaas/Content/API/Concepts/apisigningkey.htm) to learn how to create your API signing key for your config file.
|
||||
|
||||
### Required IAM Policies
|
||||
|
||||
Ensure your OCI user or instance has the following policy statements:
|
||||
|
||||
```
|
||||
Allow group <group_name> to use generative-ai-inference-endpoints in compartment <compartment_name>
|
||||
Allow group <group_name> to manage generative-ai-inference-endpoints in compartment <compartment_name>
|
||||
```
|
||||
|
||||
## Supported Services
|
||||
|
||||
### Inference: OCI Generative AI
|
||||
Oracle Cloud Infrastructure Generative AI provides access to high-performance AI models through OCI's Platform-as-a-Service offering. The service supports:
|
||||
|
||||
- **Chat Completions**: Conversational AI with context awareness
|
||||
- **Text Generation**: Complete prompts and generate text content
|
||||
|
||||
#### Available Models
|
||||
Common OCI Generative AI models include access to Meta, Cohere, OpenAI, Grok, and more models.
|
||||
|
||||
### Safety: Llama Guard
|
||||
For content safety and moderation, this distribution uses Meta's LlamaGuard model through the OCI Generative AI service to provide:
|
||||
- Content filtering and moderation
|
||||
- Policy compliance checking
|
||||
- Harmful content detection
|
||||
|
||||
### Vector Storage: Multiple Options
|
||||
The distribution supports several vector storage providers:
|
||||
- **FAISS**: Local in-memory vector search
|
||||
- **ChromaDB**: Distributed vector database
|
||||
- **PGVector**: PostgreSQL with vector extensions
|
||||
|
||||
### Additional Services
|
||||
- **Dataset I/O**: Local filesystem and Hugging Face integration
|
||||
- **Tool Runtime**: Web search (Brave, Tavily) and RAG capabilities
|
||||
- **Evaluation**: Meta reference evaluation framework
|
||||
|
||||
## Running Llama Stack with OCI
|
||||
|
||||
You can run the OCI distribution via Docker or local virtual environment.
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||
|
||||
```bash
|
||||
OCI_AUTH=$OCI_AUTH_TYPE OCI_REGION=$OCI_REGION OCI_COMPARTMENT_OCID=$OCI_COMPARTMENT_OCID llama stack run --port 8321 oci
|
||||
```
|
||||
|
||||
### Configuration Examples
|
||||
|
||||
#### Using Instance Principal (Recommended for Production)
|
||||
```bash
|
||||
export OCI_AUTH_TYPE=instance_principal
|
||||
export OCI_REGION=us-chicago-1
|
||||
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..<your-compartment-id>
|
||||
```
|
||||
|
||||
#### Using API Key Authentication (Development)
|
||||
```bash
|
||||
export OCI_AUTH_TYPE=config_file
|
||||
export OCI_CONFIG_FILE_PATH=~/.oci/config
|
||||
export OCI_CLI_PROFILE=DEFAULT
|
||||
export OCI_REGION=us-chicago-1
|
||||
export OCI_COMPARTMENT_OCID=ocid1.compartment.oc1..your-compartment-id
|
||||
```
|
||||
|
||||
## Regional Endpoints
|
||||
|
||||
OCI Generative AI is available in multiple regions. The service automatically routes to the appropriate regional endpoint based on your configuration. For a full list of regional model availability, visit:
|
||||
|
||||
https://docs.oracle.com/en-us/iaas/Content/generative-ai/overview.htm#regions
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Authentication Errors**: Verify your OCI credentials and IAM policies
|
||||
2. **Model Not Found**: Ensure the model OCID is correct and the model is available in your region
|
||||
3. **Permission Denied**: Check compartment permissions and Generative AI service access
|
||||
4. **Region Unavailable**: Verify the specified region supports Generative AI services
|
||||
|
### Getting Help

For additional support:

- [OCI Generative AI Documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm)
- [Llama Stack Issues](https://github.com/meta-llama/llama-stack/issues)
src/llama_stack/distributions/oci/oci.py (new file, 108 lines)
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.remote.inference.oci.config import OCIConfig


def get_distribution_template(name: str = "oci") -> DistributionTemplate:
    providers = {
        "inference": [BuildProvider(provider_type="remote::oci")],
        "vector_io": [
            BuildProvider(provider_type="inline::faiss"),
            BuildProvider(provider_type="remote::chromadb"),
            BuildProvider(provider_type="remote::pgvector"),
        ],
        "safety": [BuildProvider(provider_type="inline::llama-guard")],
        "agents": [BuildProvider(provider_type="inline::meta-reference")],
        "eval": [BuildProvider(provider_type="inline::meta-reference")],
        "datasetio": [
            BuildProvider(provider_type="remote::huggingface"),
            BuildProvider(provider_type="inline::localfs"),
        ],
        "scoring": [
            BuildProvider(provider_type="inline::basic"),
            BuildProvider(provider_type="inline::llm-as-judge"),
            BuildProvider(provider_type="inline::braintrust"),
        ],
        "tool_runtime": [
            BuildProvider(provider_type="remote::brave-search"),
            BuildProvider(provider_type="remote::tavily-search"),
            BuildProvider(provider_type="inline::rag-runtime"),
            BuildProvider(provider_type="remote::model-context-protocol"),
        ],
        "files": [BuildProvider(provider_type="inline::localfs")],
    }

    inference_provider = Provider(
        provider_id="oci",
        provider_type="remote::oci",
        config=OCIConfig.sample_run_config(),
    )

    vector_io_provider = Provider(
        provider_id="faiss",
        provider_type="inline::faiss",
        config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
    )

    files_provider = Provider(
        provider_id="meta-reference-files",
        provider_type="inline::localfs",
        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
    )
    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
    ]

    return DistributionTemplate(
        name=name,
        distro_type="remote_hosted",
        description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services",
        container_image=None,
        template_path=Path(__file__).parent / "doc_template.md",
        providers=providers,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                    "vector_io": [vector_io_provider],
                    "files": [files_provider],
                },
                default_tool_groups=default_tool_groups,
            ),
        },
        run_config_env_vars={
            "OCI_AUTH_TYPE": (
                "instance_principal",
                "OCI authentication type (instance_principal or config_file)",
            ),
            "OCI_REGION": (
                "",
                "OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)",
            ),
            "OCI_COMPARTMENT_OCID": (
                "",
                "OCI compartment ID for the Generative AI service",
            ),
            "OCI_CONFIG_FILE_PATH": (
                "~/.oci/config",
                "OCI config file path (required if OCI_AUTH_TYPE is config_file)",
            ),
            "OCI_CLI_PROFILE": (
                "DEFAULT",
                "OCI CLI profile name to use from config file",
            ),
        },
    )
src/llama_stack/distributions/oci/run.yaml (new file, 136 lines)
@@ -0,0 +1,136 @@
version: 2
image_name: oci
apis:
- agents
- datasetio
- eval
- files
- inference
- safety
- scoring
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: oci
    provider_type: remote::oci
    config:
      oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal}
      oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}
      oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT}
      oci_region: ${env.OCI_REGION:=us-ashburn-1}
      oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      persistence:
        namespace: vector_io::faiss
        backend: kv_default
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence:
        agent_state:
          namespace: agents
          backend: kv_default
        responses:
          table_name: responses
          backend: sql_default
          max_write_queue_size: 10000
          num_writers: 4
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        namespace: eval
        backend: kv_default
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        namespace: datasetio::huggingface
        backend: kv_default
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        namespace: datasetio::localfs
        backend: kv_default
  scoring:
  - provider_id: basic
    provider_type: inline::basic
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:=}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
    config:
      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files}
      metadata_store:
        table_name: files_metadata
        backend: sql_default
storage:
  backends:
    kv_default:
      type: kv_sqlite
      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db
    sql_default:
      type: sql_sqlite
      db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db
  stores:
    metadata:
      namespace: registry
      backend: kv_default
    inference:
      table_name: inference_store
      backend: sql_default
      max_write_queue_size: 10000
      num_writers: 4
    conversations:
      table_name: openai_conversations
      backend: sql_default
    prompts:
      namespace: prompts
      backend: kv_default
registered_resources:
  models: []
  shields: []
  vector_dbs: []
  datasets: []
  scoring_fns: []
  benchmarks: []
  tool_groups:
  - toolgroup_id: builtin::websearch
    provider_id: tavily-search
server:
  port: 8321
telemetry:
  enabled: true
|
|
@ -5,8 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
BenchmarkInput,
|
||||
BuildProvider,
|
||||
|
|
@ -34,6 +32,7 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
|
|||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack_api import DatasetPurpose, ModelType, URIDataSource
|
||||
|
||||
|
||||
def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
|
||||
|
|
|
|||
|
|
@ -27,12 +27,12 @@ providers:
|
|||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
base_url: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
base_url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ providers:
|
|||
- provider_id: vllm-inference
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=http://localhost:8000/v1}
|
||||
base_url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
|
|
|
|||
|
|
@ -18,44 +18,43 @@ providers:
|
|||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
base_url: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
base_url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
base_url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
base_url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
base_url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
|
|
@ -77,18 +76,18 @@ providers:
|
|||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
base_url: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
base_url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
base_url: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
|
|
@ -169,20 +168,15 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
responses_store:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
post_training:
|
||||
- provider_id: huggingface-gpu
|
||||
provider_type: inline::huggingface-gpu
|
||||
|
|
@ -241,10 +235,10 @@ providers:
|
|||
config:
|
||||
kvstore:
|
||||
namespace: batches
|
||||
backend: kv_postgres
|
||||
backend: kv_default
|
||||
storage:
|
||||
backends:
|
||||
kv_postgres:
|
||||
kv_default:
|
||||
type: kv_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
|
|
@ -252,7 +246,7 @@ storage:
|
|||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
|
||||
sql_postgres:
|
||||
sql_default:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
|
|
@ -262,27 +256,44 @@ storage:
|
|||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_postgres
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_postgres
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_postgres
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_postgres
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
safety:
|
||||
default_shield_id: llama-guard
|
||||
|
|
|
|||
|
|
@ -18,44 +18,43 @@ providers:
|
|||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
base_url: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
base_url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
base_url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
base_url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
base_url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
|
|
@ -77,18 +76,18 @@ providers:
|
|||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
base_url: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
base_url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
base_url: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
|
|
|
|||
|
|
@ -18,44 +18,43 @@ providers:
|
|||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
base_url: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
base_url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
base_url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
base_url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
base_url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
|
|
@ -77,18 +76,18 @@ providers:
|
|||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
base_url: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
base_url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
base_url: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
|
|
@ -169,20 +168,15 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
responses_store:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
db: ${env.POSTGRES_DB:=llamastack}
|
||||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
persistence:
|
||||
agent_state:
|
||||
namespace: agents
|
||||
backend: kv_default
|
||||
responses:
|
||||
table_name: responses
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
post_training:
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
|
|
@ -238,10 +232,10 @@ providers:
|
|||
config:
|
||||
kvstore:
|
||||
namespace: batches
|
||||
backend: kv_postgres
|
||||
backend: kv_default
|
||||
storage:
|
||||
backends:
|
||||
kv_postgres:
|
||||
kv_default:
|
||||
type: kv_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
|
|
@ -249,7 +243,7 @@ storage:
|
|||
user: ${env.POSTGRES_USER:=llamastack}
|
||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
|
||||
sql_postgres:
|
||||
sql_default:
|
||||
type: sql_postgres
|
||||
host: ${env.POSTGRES_HOST:=localhost}
|
||||
port: ${env.POSTGRES_PORT:=5432}
|
||||
|
|
@ -259,27 +253,44 @@ storage:
|
|||
stores:
|
||||
metadata:
|
||||
namespace: registry
|
||||
backend: kv_postgres
|
||||
backend: kv_default
|
||||
inference:
|
||||
table_name: inference_store
|
||||
backend: sql_postgres
|
||||
backend: sql_default
|
||||
max_write_queue_size: 10000
|
||||
num_writers: 4
|
||||
conversations:
|
||||
table_name: openai_conversations
|
||||
backend: sql_postgres
|
||||
backend: sql_default
|
||||
prompts:
|
||||
namespace: prompts
|
||||
backend: kv_postgres
|
||||
backend: kv_default
|
||||
registered_resources:
|
||||
models: []
|
||||
shields: []
|
||||
shields:
|
||||
- shield_id: llama-guard
|
||||
provider_id: ${env.SAFETY_MODEL:+llama-guard}
|
||||
provider_shield_id: ${env.SAFETY_MODEL:=}
|
||||
- shield_id: code-scanner
|
||||
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
|
||||
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
server:
|
||||
port: 8321
|
||||
telemetry:
|
||||
enabled: true
|
||||
vector_stores:
|
||||
default_provider_id: faiss
|
||||
default_embedding_model:
|
||||
provider_id: sentence-transformers
|
||||
model_id: nomic-ai/nomic-embed-text-v1.5
|
||||
safety:
|
||||
default_shield_id: llama-guard
|
||||
|
|
|
|||
|
|
@ -18,44 +18,43 @@ providers:
|
|||
- provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
|
||||
provider_type: remote::cerebras
|
||||
config:
|
||||
base_url: https://api.cerebras.ai
|
||||
base_url: https://api.cerebras.ai/v1
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
- provider_id: ${env.OLLAMA_URL:+ollama}
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
|
||||
- provider_id: ${env.VLLM_URL:+vllm}
|
||||
provider_type: remote::vllm
|
||||
config:
|
||||
url: ${env.VLLM_URL:=}
|
||||
base_url: ${env.VLLM_URL:=}
|
||||
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
|
||||
api_token: ${env.VLLM_API_TOKEN:=fake}
|
||||
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
|
||||
- provider_id: ${env.TGI_URL:+tgi}
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: ${env.TGI_URL:=}
|
||||
base_url: ${env.TGI_URL:=}
|
||||
- provider_id: fireworks
|
||||
provider_type: remote::fireworks
|
||||
config:
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
base_url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
base_url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
- provider_id: bedrock
|
||||
provider_type: remote::bedrock
|
||||
config:
|
||||
api_key: ${env.AWS_BEDROCK_API_KEY:=}
|
||||
api_key: ${env.AWS_BEARER_TOKEN_BEDROCK:=}
|
||||
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
|
||||
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
|
||||
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
|
||||
api_key: ${env.NVIDIA_API_KEY:=}
|
||||
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
|
||||
- provider_id: openai
|
||||
provider_type: remote::openai
|
||||
config:
|
||||
|
|
@ -77,18 +76,18 @@ providers:
|
|||
- provider_id: groq
|
||||
provider_type: remote::groq
|
||||
config:
|
||||
url: https://api.groq.com
|
||||
base_url: https://api.groq.com/openai/v1
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
- provider_id: sambanova
|
||||
provider_type: remote::sambanova
|
||||
config:
|
||||
url: https://api.sambanova.ai/v1
|
||||
base_url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
- provider_id: ${env.AZURE_API_KEY:+azure}
|
||||
provider_type: remote::azure
|
||||
config:
|
||||
api_key: ${env.AZURE_API_KEY:=}
|
||||
api_base: ${env.AZURE_API_BASE:=}
|
||||
base_url: ${env.AZURE_API_BASE:=}
|
||||
api_version: ${env.AZURE_API_VERSION:=}
|
||||
api_type: ${env.AZURE_API_TYPE:=}
|
||||
- provider_id: sentence-transformers
|
||||
|
|
|
|||
|
|
@ -17,14 +17,10 @@ from llama_stack.core.datatypes import (
|
|||
ToolGroupInput,
|
||||
VectorStoresConfig,
|
||||
)
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
SqlStoreReference,
|
||||
)
|
||||
from llama_stack.core.storage.kvstore.config import PostgresKVStoreConfig
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
|
||||
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
|
|
@ -41,8 +37,7 @@ from llama_stack.providers.remote.vector_io.pgvector.config import (
|
|||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.config import QdrantVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
from llama_stack_api import RemoteProviderSpec
|
||||
|
||||
|
||||
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
||||
|
|
@ -155,10 +150,11 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
BuildProvider(provider_type="inline::reference"),
|
||||
],
|
||||
}
|
||||
files_config = LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}")
|
||||
files_provider = Provider(
|
||||
provider_id="meta-reference-files",
|
||||
provider_type="inline::localfs",
|
||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
config=files_config,
|
||||
)
|
||||
embedding_provider = Provider(
|
||||
provider_id="sentence-transformers",
|
||||
|
|
@ -188,7 +184,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
|
||||
),
|
||||
]
|
||||
postgres_config = PostgresSqlStoreConfig.sample_run_config()
|
||||
postgres_sql_config = PostgresSqlStoreConfig.sample_run_config()
|
||||
postgres_kv_config = PostgresKVStoreConfig.sample_run_config()
|
||||
default_overrides = {
|
||||
"inference": remote_inference_providers + [embedding_provider],
|
||||
"vector_io": [
|
||||
|
|
@ -245,6 +242,33 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
"files": [files_provider],
|
||||
}
|
||||
|
||||
base_run_settings = RunConfigSettings(
|
||||
provider_overrides=default_overrides,
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=default_shields,
|
||||
vector_stores_config=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="sentence-transformers",
|
||||
model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
),
|
||||
),
|
||||
safety_config=SafetyConfig(
|
||||
default_shield_id="llama-guard",
|
||||
),
|
||||
)
|
||||
|
||||
postgres_run_settings = base_run_settings.model_copy(
|
||||
update={
|
||||
"storage_backends": {
|
||||
"kv_default": postgres_kv_config,
|
||||
"sql_default": postgres_sql_config,
|
||||
}
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
distro_type="self_hosted",
|
||||
|
|
@ -254,71 +278,8 @@ def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
|||
providers=providers,
|
||||
additional_pip_packages=list(set(PostgresSqlStoreConfig.pip_packages() + PostgresKVStoreConfig.pip_packages())),
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
provider_overrides=default_overrides,
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
default_shields=default_shields,
|
||||
vector_stores_config=VectorStoresConfig(
|
||||
default_provider_id="faiss",
|
||||
default_embedding_model=QualifiedModel(
|
||||
provider_id="sentence-transformers",
|
||||
model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
),
|
||||
),
|
||||
safety_config=SafetyConfig(
|
||||
default_shield_id="llama-guard",
|
||||
),
|
||||
),
|
||||
"run-with-postgres-store.yaml": RunConfigSettings(
|
||||
provider_overrides={
|
||||
**default_overrides,
|
||||
"agents": [
|
||||
Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="inline::meta-reference",
|
||||
config=dict(
|
||||
persistence_store=postgres_config,
|
||||
responses_store=postgres_config,
|
||||
),
|
||||
)
|
||||
],
|
||||
"batches": [
|
||||
Provider(
|
||||
provider_id="reference",
|
||||
provider_type="inline::reference",
|
||||
config=dict(
|
||||
kvstore=KVStoreReference(
|
||||
backend="kv_postgres",
|
||||
namespace="batches",
|
||||
).model_dump(exclude_none=True),
|
||||
),
|
||||
)
|
||||
],
|
||||
},
|
||||
storage_backends={
|
||||
"kv_postgres": PostgresKVStoreConfig.sample_run_config(),
|
||||
"sql_postgres": postgres_config,
|
||||
},
|
||||
storage_stores={
|
||||
"metadata": KVStoreReference(
|
||||
backend="kv_postgres",
|
||||
namespace="registry",
|
||||
).model_dump(exclude_none=True),
|
||||
"inference": InferenceStoreReference(
|
||||
backend="sql_postgres",
|
||||
table_name="inference_store",
|
||||
).model_dump(exclude_none=True),
|
||||
"conversations": SqlStoreReference(
|
||||
backend="sql_postgres",
|
||||
table_name="openai_conversations",
|
||||
).model_dump(exclude_none=True),
|
||||
"prompts": KVStoreReference(
|
||||
backend="kv_postgres",
|
||||
namespace="prompts",
|
||||
).model_dump(exclude_none=True),
|
||||
},
|
||||
),
|
||||
"run.yaml": base_run_settings,
|
||||
"run-with-postgres-store.yaml": postgres_run_settings,
|
||||
},
|
||||
run_config_env_vars={
|
||||
"LLAMA_STACK_PORT": (
|
||||
|
|
|
|||
|
|
@ -12,8 +12,6 @@ import rich
|
|||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.datasets import DatasetPurpose
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.datatypes import (
|
||||
LLAMA_STACK_RUN_CONFIG_VERSION,
|
||||
Api,
|
||||
|
|
@ -37,13 +35,14 @@ from llama_stack.core.storage.datatypes import (
|
|||
SqlStoreReference,
|
||||
StorageBackendType,
|
||||
)
|
||||
from llama_stack.core.storage.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.kvstore.config import get_pip_packages as get_kv_pip_packages
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages
|
||||
from llama_stack_api import DatasetPurpose, ModelType
|
||||
|
||||
|
||||
def filter_empty_values(obj: Any) -> Any:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ providers:
|
|||
- provider_id: watsonx
|
||||
provider_type: remote::watsonx
|
||||
config:
|
||||
url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
|
||||
base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
|
||||
api_key: ${env.WATSONX_API_KEY:=}
|
||||
project_id: ${env.WATSONX_PROJECT_ID:=}
|
||||
vector_io:
|
||||
|
|
|
|||
|
|
@ -26,8 +26,10 @@ from fairscale.nn.model_parallel.initialize import (
|
|||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
|
||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, LLMInput
|
||||
from .model import Transformer
|
||||
|
|
|
|||
|
|
@ -15,13 +15,10 @@ from pathlib import Path
|
|||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall, ToolDefinition, ToolPromptFormat
|
||||
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff.