Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-05 02:17:31 +00:00)

Merge branch 'main' into add-mongodb-vector_io

Commit d0064fc915: 426 changed files with 99110 additions and 62778 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import Sequence
 from typing import Annotated, Any, Literal

 from pydantic import BaseModel, Field, model_validator

@@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """

-    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
+    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
     type: Literal["message"] = "message"

@@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     """

     id: str
-    queries: list[str]
+    queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None


 @json_schema_type

@@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: list[OpenAIResponseOutput]
+    output: Sequence[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
     previous_response_id: str | None = None
     prompt: OpenAIResponsePrompt | None = None

@@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
     # before the field was added. New responses will have this set always.
     text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
     top_p: float | None = None
-    tools: list[OpenAIResponseTool] | None = None
+    tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
     usage: OpenAIResponseUsage | None = None
     instructions: str | None = None

@@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
     :param object: Object type identifier, always "list"
     """

-    data: list[OpenAIResponseInput]
+    data: Sequence[OpenAIResponseInput]
     object: Literal["list"] = "list"

@@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
     :param input: List of input items that led to this response
     """

-    input: list[OpenAIResponseInput]
+    input: Sequence[OpenAIResponseInput]

     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""

@@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
     :param object: Object type identifier, always "list"
     """

-    data: list[OpenAIResponseObjectWithInput]
+    data: Sequence[OpenAIResponseObjectWithInput]
     has_more: bool
     first_id: str
     last_id: str
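The recurring `list` → `Sequence` change in the hunks above relaxes these request/response annotations for static type checking. A minimal sketch of the list-invariance issue that `Sequence` avoids (the `Item` models below are hypothetical, not from this file):

```python
from collections.abc import Sequence

from pydantic import BaseModel


class Item(BaseModel):
    text: str


class SpecialItem(Item):
    extra: str = ""


class WithList(BaseModel):
    data: list[Item]  # invariant: a list[SpecialItem] argument is rejected by type checkers


class WithSequence(BaseModel):
    data: Sequence[Item]  # covariant: a list[SpecialItem] argument is accepted


special: list[SpecialItem] = [SpecialItem(text="hi")]
WithSequence(data=special)   # fine for mypy/pyright; pydantic still validates at runtime
# WithList(data=special)     # flagged by a type checker, even though it runs fine
```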
@@ -4,14 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Protocol, runtime_checkable
+from typing import Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

-from llama_stack.apis.version import LLAMA_STACK_API_V1
+from llama_stack.apis.version import (
+    LLAMA_STACK_API_V1,
+)
 from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.schema_utils import json_schema_type, webmethod

+# Valid values for the route filter parameter.
+# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
+# Special filter value: "deprecated" (shows deprecated routes regardless of level)
+ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]


 @json_schema_type
 class RouteInfo(BaseModel):

@@ -64,11 +71,12 @@ class Inspect(Protocol):
     """

     @webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
-    async def list_routes(self) -> ListRoutesResponse:
+    async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
         """List routes.

         List all available API routes with their methods and implementing providers.

+        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
         :returns: Response containing information about all available routes.
         """
         ...
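A hedged usage sketch of the new `api_filter` parameter over plain HTTP. The query-parameter name follows the method signature above; the base URL and the `v1`-prefixed path are assumptions about how the server mounts this route:

```python
import httpx

BASE_URL = "http://localhost:8321"  # assumed local Llama Stack server

# Default behaviour: only non-deprecated v1 routes.
v1_routes = httpx.get(f"{BASE_URL}/v1/inspect/routes").json()

# Filter by API level, or list deprecated routes across all levels.
alpha_routes = httpx.get(f"{BASE_URL}/v1/inspect/routes", params={"api_filter": "v1alpha"}).json()
deprecated_routes = httpx.get(f"{BASE_URL}/v1/inspect/routes", params={"api_filter": "deprecated"}).json()
```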
@@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
     :object: The object type, which will be "model"
     :created: The Unix timestamp in seconds when the model was created
     :owned_by: The owner of the model
+    :custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
     """

     id: str
     object: Literal["model"] = "model"
     created: int
     owned_by: str
+    custom_metadata: dict[str, Any] | None = None


 class OpenAIListModelsResponse(BaseModel):

@@ -113,7 +115,7 @@ class Models(Protocol):
         """
         ...

-    @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_list_models(self) -> OpenAIListModelsResponse:
         """List models using the OpenAI API.
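A sketch of a models-list entry carrying the new `custom_metadata` field. The import path and the concrete values are illustrative assumptions; the metadata keys mirror the ones populated by the routing table later in this diff:

```python
from llama_stack.apis.models import OpenAIModel  # assumed import path

entry = OpenAIModel(
    id="ollama/llama3.2:3b",
    created=1733400000,
    owned_by="llama_stack",
    custom_metadata={
        "model_type": "llm",
        "provider_id": "ollama",
        "provider_resource_id": "llama3.2:3b",
    },
)
print(entry.model_dump_json(indent=2))  # `object` defaults to "model"
```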
@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .synthetic_data_generation import *
@ -1,77 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
"""The type of filtering function.
|
||||
|
||||
:cvar none: No filtering applied, accept all generated synthetic data
|
||||
:cvar random: Random sampling of generated data points
|
||||
:cvar top_k: Keep only the top-k highest scoring synthetic data samples
|
||||
:cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
|
||||
:cvar top_k_top_p: Combined top-k and top-p filtering strategy
|
||||
:cvar sigmoid: Apply sigmoid function for probability-based filtering
|
||||
"""
|
||||
|
||||
none = "none"
|
||||
random = "random"
|
||||
top_k = "top_k"
|
||||
top_p = "top_p"
|
||||
top_k_top_p = "top_k_top_p"
|
||||
sigmoid = "sigmoid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationRequest(BaseModel):
|
||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
"""
|
||||
|
||||
dialogs: list[Message]
|
||||
filtering_function: FilteringFunction = FilteringFunction.none
|
||||
model: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationResponse(BaseModel):
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
|
||||
|
||||
:param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
|
||||
:param statistics: (Optional) Statistical information about the generation process and filtering results
|
||||
"""
|
||||
|
||||
synthetic_data: list[dict[str, Any]]
|
||||
statistics: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
|
||||
def synthetic_data_generate(
|
||||
self,
|
||||
dialogs: list[Message],
|
||||
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||
model: str | None = None,
|
||||
) -> SyntheticDataGenerationResponse:
|
||||
"""Generate synthetic data based on input dialogs and apply filtering.
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
:returns: Response containing filtered synthetic data samples and optional statistics
|
||||
"""
|
||||
...
|
||||
|
|
@ -8,7 +8,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import uuid
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Body
|
||||
|
|
@ -18,7 +17,6 @@ from llama_stack.apis.inference import InterleavedContent
|
|||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.strong_typing.schema import register_schema
|
||||
|
||||
|
|
@ -61,38 +59,19 @@ class Chunk(BaseModel):
|
|||
"""
|
||||
A chunk of content that can be inserted into a vector database.
|
||||
:param content: The content of the chunk, which can be interleaved text, images, or other types.
|
||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||
:param chunk_id: Unique identifier for the chunk. Must be provided explicitly.
|
||||
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
|
||||
:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
|
||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
|
||||
The `chunk_metadata` is required backend functionality.
|
||||
"""
|
||||
|
||||
content: InterleavedContent
|
||||
chunk_id: str
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
embedding: list[float] | None = None
|
||||
# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
|
||||
stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
|
||||
chunk_metadata: ChunkMetadata | None = None
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
def model_post_init(self, __context):
|
||||
# Extract chunk_id from metadata if present
|
||||
if self.metadata and "chunk_id" in self.metadata:
|
||||
self.stored_chunk_id = self.metadata.pop("chunk_id")
|
||||
|
||||
@property
|
||||
def chunk_id(self) -> str:
|
||||
"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
|
||||
if self.stored_chunk_id:
|
||||
return self.stored_chunk_id
|
||||
|
||||
if "document_id" in self.metadata:
|
||||
return generate_chunk_id(self.metadata["document_id"], str(self.content))
|
||||
|
||||
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
|
||||
|
||||
@property
|
||||
def document_id(self) -> str | None:
|
||||
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
|
||||
|
|
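A small construction sketch matching the updated `Chunk` model, where `chunk_id` is now an explicit required field rather than being derived or aliased. The import path and sample values are assumptions, not taken from this diff:

```python
from llama_stack.apis.vector_io import Chunk  # assumed import path

chunk = Chunk(
    content="MongoDB stores documents as BSON.",
    chunk_id="doc-42-chunk-0",
    metadata={"document_id": "doc-42", "source": "docs/mongodb.md"},
)
print(chunk.chunk_id)  # "doc-42-chunk-0" — no fallback generation in the new model
```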
|
|||
|
|
@ -8,16 +8,30 @@ import argparse
|
|||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import LoggingConfig, get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
|
@ -68,6 +82,12 @@ class StackRun(Subcommand):
|
|||
action="store_true",
|
||||
help="Start the UI server",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
|
||||
)
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
|
@ -93,6 +113,55 @@ class StackRun(Subcommand):
|
|||
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
elif args.providers:
|
||||
provider_list: dict[str, list[Provider]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider_type = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
f"{api} is not a valid API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider_type in providers_for_api:
|
||||
config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
|
||||
if config_type is not None and hasattr(config_type, "sample_run_config"):
|
||||
config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
|
||||
else:
|
||||
config = {}
|
||||
provider = Provider(
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
provider_id=provider_type.split("::")[1],
|
||||
)
|
||||
provider_list.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
run_config = self._generate_run_config_from_providers(providers=provider_list)
|
||||
config_dict = run_config.model_dump(mode="json")
|
||||
|
||||
# Write config to disk in providers-run directory
|
||||
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||
config_file = distro_dir / "run.yaml"
|
||||
|
||||
logger.info(f"Writing generated config to: {config_file}")
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
|
|
@ -106,7 +175,8 @@ class StackRun(Subcommand):
|
|||
|
||||
try:
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(str(config.external_providers_dir)):
|
||||
# Create external_providers_dir if it's specified and doesn't exist
|
||||
if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)):
|
||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
|
@ -127,7 +197,7 @@ class StackRun(Subcommand):
|
|||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or ["::", "0.0.0.0"]
|
||||
host = config.server.host or "0.0.0.0"
|
||||
|
||||
# Set the config file in environment so create_app can find it
|
||||
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
|
||||
|
|
@ -139,6 +209,7 @@ class StackRun(Subcommand):
|
|||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
"workers": config.server.workers,
|
||||
}
|
||||
|
||||
keyfile = config.server.tls_keyfile
|
||||
|
|
@ -212,3 +283,44 @@ class StackRun(Subcommand):
|
|||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
||||
|
||||
def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
|
||||
apis = list(providers.keys())
|
||||
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||
# need somewhere to put the storage.
|
||||
os.makedirs(distro_dir, exist_ok=True)
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(
|
||||
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
|
||||
),
|
||||
"sql_default": SqliteSqlStoreConfig(
|
||||
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
|
||||
),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="registry",
|
||||
),
|
||||
inference=InferenceStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="inference_store",
|
||||
),
|
||||
conversations=SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="openai_conversations",
|
||||
),
|
||||
prompts=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="prompts",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return StackRunConfig(
|
||||
image_name="providers-run",
|
||||
apis=apis,
|
||||
providers=providers,
|
||||
storage=storage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from llama_stack.core.distribution import (
|
|||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -194,19 +193,11 @@ def upgrade_from_routing_table(
|
|||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
|
|
|||
|
|
@@ -473,6 +473,10 @@ class ServerConfig(BaseModel):
         "- true: Enable localhost CORS for development\n"
         "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
     )
+    workers: int = Field(
+        default=1,
+        description="Number of workers to use for the server",
+    )


 class StackRunConfig(BaseModel):
|
@@ -15,6 +15,7 @@ from llama_stack.apis.inspect import (
     RouteInfo,
     VersionInfo,
 )
+from llama_stack.apis.version import LLAMA_STACK_API_V1
 from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.core.external import load_external_apis
 from llama_stack.core.server.routes import get_all_api_routes

@@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect):
     async def initialize(self) -> None:
         pass

-    async def list_routes(self) -> ListRoutesResponse:
+    async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
         run_config: StackRunConfig = self.config.run_config

+        # Helper function to determine if a route should be included based on api_filter
+        def should_include_route(webmethod) -> bool:
+            if api_filter is None:
+                # Default: only non-deprecated v1 APIs
+                return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
+            elif api_filter == "deprecated":
+                # Special filter: show deprecated routes regardless of their actual level
+                return bool(webmethod.deprecated)
+            else:
+                # Filter by API level (non-deprecated routes only)
+                return not webmethod.deprecated and webmethod.level == api_filter
+
         ret = []
         external_apis = load_external_apis(run_config)
         all_endpoints = get_all_api_routes(external_apis)

@@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect):
                         method=next(iter([m for m in e.methods if m != "HEAD"])),
                         provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
                     )
-                    for e, _ in endpoints
-                    if e.methods is not None
+                    for e, webmethod in endpoints
+                    if e.methods is not None and should_include_route(webmethod)
                 ]
             )
         else:

@@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect):
                         method=next(iter([m for m in e.methods if m != "HEAD"])),
                         provider_types=[p.provider_type for p in providers],
                     )
-                    for e, _ in endpoints
-                    if e.methods is not None
+                    for e, webmethod in endpoints
+                    if e.methods is not None and should_include_route(webmethod)
                 ]
             )
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
|
|
@ -15,20 +15,10 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
|||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
|
|
@ -45,15 +35,13 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
Order,
|
||||
RerankResponse,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent, MetricInResponse
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent
|
||||
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
|
|
@ -153,35 +141,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
return metric_events
|
||||
|
||||
async def _compute_and_log_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> list[MetricInResponse]:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
|
||||
)
|
||||
if self.telemetry_enabled:
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: list[Message] | InterleavedContent,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> int | None:
|
||||
if not hasattr(self, "formatter") or self.formatter is None:
|
||||
return None
|
||||
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
|
||||
model = await self.routing_table.get_object_by_identifier("model", model_id)
|
||||
if model:
|
||||
|
|
@ -375,121 +334,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
return health_statuses
|
||||
|
||||
async def stream_tokens_and_compute_metrics(
|
||||
self,
|
||||
response,
|
||||
prompt_tokens,
|
||||
fully_qualified_model_id: str,
|
||||
provider_id: str,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
completion_text = ""
|
||||
async for chunk in response:
|
||||
complete = False
|
||||
if hasattr(chunk, "event"): # only ChatCompletions have .event
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
else:
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
# if we are done receiving tokens
|
||||
if complete:
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for streaming completion metrics
|
||||
if self.telemetry_enabled:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in [
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]: # Only log completion and total tokens
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
else:
|
||||
# Fallback if no telemetry
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
yield chunk
|
||||
|
||||
async def count_tokens_and_compute_metrics(
|
||||
self,
|
||||
response: ChatCompletionResponse | CompletionResponse,
|
||||
prompt_tokens,
|
||||
fully_qualified_model_id: str,
|
||||
provider_id: str,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
):
|
||||
if isinstance(response, ChatCompletionResponse):
|
||||
content = [response.completion_message]
|
||||
else:
|
||||
content = response.content
|
||||
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for completion metrics
|
||||
if self.telemetry_enabled:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
|
||||
|
||||
# Fallback if no telemetry
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def stream_tokens_and_compute_metrics_openai_chat(
|
||||
self,
|
||||
response: AsyncIterator[OpenAIChatCompletionChunk],
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
|
|||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
|
@ -42,19 +44,104 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
await self.update_registered_models(provider_id, models)
|
||||
|
||||
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
|
||||
"""
|
||||
Fetch models from providers that have credentials in the current request's provider_data.
|
||||
|
||||
This allows users to see models available to them from providers that require
|
||||
per-request API keys (via X-LlamaStack-Provider-Data header).
|
||||
|
||||
Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
|
||||
cache them in the registry since they are user-specific.
|
||||
"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return []
|
||||
|
||||
dynamic_models = []
|
||||
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
# Check if this provider supports provider_data
|
||||
if not isinstance(provider, NeedsRequestProviderData):
|
||||
continue
|
||||
|
||||
# Check if provider has a validator (some providers like ollama don't need per-request credentials)
|
||||
spec = getattr(provider, "__provider_spec__", None)
|
||||
if not spec or not getattr(spec, "provider_data_validator", None):
|
||||
continue
|
||||
|
||||
# Validate provider_data silently - we're speculatively checking all providers
|
||||
# so validation failures are expected when user didn't provide keys for this provider
|
||||
try:
|
||||
validator = instantiate_class_type(spec.provider_data_validator)
|
||||
validator(**provider_data)
|
||||
except Exception:
|
||||
# User didn't provide credentials for this provider - skip silently
|
||||
continue
|
||||
|
||||
# Validation succeeded! User has credentials for this provider
|
||||
# Now try to list models
|
||||
try:
|
||||
models = await provider.list_models()
|
||||
if not models:
|
||||
continue
|
||||
|
||||
# Ensure models have fully qualified identifiers with provider_id prefix
|
||||
for model in models:
|
||||
# Only add prefix if model identifier doesn't already have it
|
||||
if not model.identifier.startswith(f"{provider_id}/"):
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
dynamic_models.append(model)
|
||||
|
||||
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
|
||||
continue
|
||||
|
||||
return dynamic_models
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
# Get models from registry
|
||||
registry_models = await self.get_all_with_type("model")
|
||||
|
||||
# Get additional models available via provider_data (user-specific, not cached)
|
||||
dynamic_models = await self._get_dynamic_models_from_provider_data()
|
||||
|
||||
# Combine, avoiding duplicates (registry takes precedence)
|
||||
registry_identifiers = {m.identifier for m in registry_models}
|
||||
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
|
||||
|
||||
return ListModelsResponse(data=registry_models + unique_dynamic_models)
|
||||
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
models = await self.get_all_with_type("model")
|
||||
# Get models from registry
|
||||
registry_models = await self.get_all_with_type("model")
|
||||
|
||||
# Get additional models available via provider_data (user-specific, not cached)
|
||||
dynamic_models = await self._get_dynamic_models_from_provider_data()
|
||||
|
||||
# Combine, avoiding duplicates (registry takes precedence)
|
||||
registry_identifiers = {m.identifier for m in registry_models}
|
||||
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
|
||||
|
||||
all_models = registry_models + unique_dynamic_models
|
||||
|
||||
openai_models = [
|
||||
OpenAIModel(
|
||||
id=model.identifier,
|
||||
object="model",
|
||||
created=int(time.time()),
|
||||
owned_by="llama_stack",
|
||||
custom_metadata={
|
||||
"model_type": model.model_type,
|
||||
"provider_id": model.provider_id,
|
||||
"provider_resource_id": model.provider_resource_id,
|
||||
**model.metadata,
|
||||
},
|
||||
)
|
||||
for model in models
|
||||
for model in all_models
|
||||
]
|
||||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
|
|
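A hedged sketch of how the per-request credentials described above surface additional models in the list. The `X-LlamaStack-Provider-Data` header name comes from the docstring in this diff; the base URL, the provider key name inside the header, and the exact response shape are assumptions:

```python
import json

import httpx

provider_data = {"together_api_key": "sk-..."}  # assumed validator field name for one provider

resp = httpx.get(
    "http://localhost:8321/v1/models",
    headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
)
# Registry models plus any models the credentialed provider exposes,
# with dynamic entries namespaced as "<provider_id>/<model_id>" and not cached.
for model in resp.json()["data"]:
    print(model["identifier"])
```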
@@ -14,6 +14,7 @@ from typing import Any
 import yaml

 from llama_stack.apis.agents import Agents
+from llama_stack.apis.batches import Batches
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO

@@ -30,7 +31,6 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
-from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl

@@ -63,8 +63,8 @@ class LlamaStack(
     Providers,
     Inference,
     Agents,
+    Batches,
     Safety,
-    SyntheticDataGeneration,
     Datasets,
     PostTraining,
     VectorIO,
|
|
|||
|
|
@ -152,6 +152,37 @@ docker run \
|
|||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
|
||||
docker run -it \
|
||||
--pull always \
|
||||
--network host \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
{% if run_configs %}
|
||||
Available run configurations for this distribution:
|
||||
{% for config in run_configs %}
|
||||
- `{{ config }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
### Via Conda
|
||||
|
||||
Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.
|
||||
|
|
|
|||
|
|
@ -68,6 +68,36 @@ docker run \
|
|||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
{% if run_configs %}
|
||||
Available run configurations for this distribution:
|
||||
{% for config in run_configs %}
|
||||
- `{{ config }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
### Via venv
|
||||
|
||||
Make sure you have the Llama Stack CLI available.
|
||||
|
|
|
|||
|
|
@ -117,13 +117,42 @@ docker run \
|
|||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
{% if run_configs %}
|
||||
Available run configurations for this distribution:
|
||||
{% for config in run_configs %}
|
||||
- `{{ config }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment.
|
||||
|
|
|
|||
|
|
@@ -424,6 +424,7 @@ class DistributionTemplate(BaseModel):
                 providers_table=providers_table,
                 run_config_env_vars=self.run_config_env_vars,
                 default_models=default_models,
+                run_configs=list(self.run_configs.keys()),
             )
         return ""
|
|
|||
|
|
@ -11,6 +11,7 @@ import uuid
|
|||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||
messages = []
|
||||
messages: list[Message] = []
|
||||
|
||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||
tool_call_ids = set()
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.tool_execution.value:
|
||||
if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
|
||||
for response in step.tool_responses:
|
||||
tool_call_ids.add(response.call_id)
|
||||
|
||||
|
|
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages.append(msg)
|
||||
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.inference.value:
|
||||
if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
|
||||
messages.append(step.model_response)
|
||||
elif step.step_type == StepType.tool_execution.value:
|
||||
elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
|
|
@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content=response.content,
|
||||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
if step.violation:
|
||||
elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
|
||||
if step.violation and step.violation.user_message:
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
|
|
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return await self.storage.create_session(name)
|
||||
|
||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||
messages = []
|
||||
messages: list[Message] = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
|
|
@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
|
||||
if is_resume:
|
||||
assert isinstance(request, AgentTurnResumeRequest)
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
||||
]
|
||||
|
|
@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
now_dt = datetime.now(UTC)
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=request.tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
completed_at=now_dt,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
|
||||
)
|
||||
steps.append(tool_execution_step)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=tool_execution_step.step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
input_messages = last_turn.input_messages
|
||||
# Cast needed due to list invariance - last_turn.input_messages is the right type
|
||||
input_messages = last_turn.input_messages # type: ignore[assignment]
|
||||
|
||||
turn_id = request.turn_id
|
||||
actual_turn_id = request.turn_id
|
||||
start_time = last_turn.started_at
|
||||
else:
|
||||
assert isinstance(request, AgentTurnCreateRequest)
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now(UTC).isoformat()
|
||||
input_messages = request.messages
|
||||
start_time = datetime.now(UTC)
|
||||
# Cast needed due to list invariance - request.messages is the right type
|
||||
input_messages = request.messages # type: ignore[assignment]
|
||||
# Use the generated turn_id from beginning of function
|
||||
actual_turn_id = turn_id if turn_id else str(uuid.uuid4())
|
||||
|
||||
output_message = None
|
||||
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
|
||||
req_sampling = (
|
||||
self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
|
||||
)
|
||||
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
turn_id=actual_turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
sampling_params=req_sampling,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
documents=req_documents,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
|
|
@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
|
||||
event.payload, "step_details"
|
||||
):
|
||||
step_details = event.payload.step_details
|
||||
steps.append(step_details)
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
turn_id=actual_turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=input_messages,
|
||||
input_messages=input_messages, # type: ignore[arg-type]
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
|
@ -345,9 +361,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||
# final boolean (to see whether an exception happened) and then explicitly testing for it.
|
||||
|
||||
if len(self.input_shields) > 0:
|
||||
if self.input_shields:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, input_messages, self.input_shields, "user-input"
|
||||
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -374,9 +390,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# for output shields run on the full input and output combination
|
||||
messages = input_messages + [final_response]
|
||||
|
||||
if len(self.output_shields) > 0:
|
||||
if self.output_shields:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, messages, self.output_shields, "assistant-output"
|
||||
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -388,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
shields: list[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
|
|
@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
shield_call_start_time = datetime.now(UTC).isoformat()
|
||||
shield_call_start_time = datetime.now(UTC)
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
)
|
||||
|
|
@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
|
||||
|
||||
output_attachments = []
|
||||
output_attachments: list[Attachment] = []
|
||||
|
||||
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
client_tools = {}
|
||||
for tool in self.agent_config.client_tools:
|
||||
client_tools[tool.name] = tool
|
||||
if self.agent_config.client_tools:
|
||||
for tool in self.agent_config.client_tools:
|
||||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now(UTC).isoformat()
|
||||
inference_start_time = datetime.now(UTC)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
|
|
@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
return value
|
||||
|
||||
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
|
||||
def _add_type(openai_msg: Any) -> OpenAIMessageParam:
|
||||
# Serialize any nested Pydantic models to plain dicts
|
||||
openai_msg = _serialize_nested(openai_msg)
|
||||
|
||||
|
|
@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages=openai_messages,
|
||||
tools=openai_tools if openai_tools else None,
|
||||
tool_choice=tool_choice,
|
||||
response_format=self.agent_config.response_format,
|
||||
response_format=self.agent_config.response_format, # type: ignore[arg-type]
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# Convert OpenAI stream back to Llama Stack format
|
||||
response_stream = convert_openai_chat_completion_stream(
|
||||
openai_stream, enable_incremental_tool_calls=True
|
||||
openai_stream, # type: ignore[arg-type]
|
||||
enable_incremental_tool_calls=True,
|
||||
)
|
||||
|
||||
async for chunk in response_stream:
|
||||
|
|
@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
delta=delta,
|
||||
)
|
||||
|
|
@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
delta=delta,
|
||||
)
|
||||
|
|
@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_attr = json.dumps(
|
||||
{
|
||||
"content": content,
|
||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||
"tool_calls": [
|
||||
json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
|
||||
],
|
||||
}
|
||||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
|
|
@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if tool_calls:
|
||||
content = ""
|
||||
|
||||
# Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
|
||||
valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
|
||||
message = CompletionMessage(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
tool_calls=valid_tool_calls if valid_tool_calls else None,
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
step_details=InferenceStep(
|
||||
# somewhere deep, we are re-assigning message or closing over some
|
||||
|
|
@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
started_at=inference_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
|
||||
if n_iter >= max_iters:
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
|
|
@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield message
|
||||
break
|
||||
|
||||
if len(message.tool_calls) == 0:
|
||||
if not message.tool_calls or len(message.tool_calls) == 0:
|
||||
if stop_reason == StopReason.end_of_turn:
|
||||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||
if len(output_attachments) > 0:
|
||||
if isinstance(message.content, list):
|
||||
message.content += output_attachments
|
||||
# List invariance - attachments are compatible at runtime
|
||||
message.content += output_attachments # type: ignore[arg-type]
|
||||
else:
|
||||
message.content = [message.content] + output_attachments
|
||||
# List invariance - attachments are compatible at runtime
|
||||
message.content = [message.content] + output_attachments # type: ignore[assignment]
|
||||
yield message
|
||||
else:
|
||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
|
|
@ -725,11 +750,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
non_client_tool_calls = []
|
||||
|
||||
# Separate client and non-client tool calls
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in client_tools:
|
||||
client_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_client_tool_calls.append(tool_call)
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in client_tools:
|
||||
client_tool_calls.append(tool_call)
|
||||
else:
|
||||
non_client_tool_calls.append(tool_call)
|
||||
|
||||
# Process non-client tool calls first
|
||||
for tool_call in non_client_tool_calls:
|
||||
|
|
@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
|
|
@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
|
|
@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if self.telemetry_enabled
|
||||
else {},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(UTC).isoformat()
|
||||
tool_execution_start_time = datetime.now(UTC)
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
|
|
@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Yield the step completion event
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
|
|
@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
tool_calls=client_tool_calls,
|
||||
tool_responses=[],
|
||||
started_at=datetime.now(UTC).isoformat(),
|
||||
started_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -868,19 +894,20 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_name_to_def: dict[str, ToolDefinition] = {}
|
||||
tool_name_to_args = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
if self.agent_config.client_tools:
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
|
||||
# Use input_schema from ToolDef directly
|
||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
# Use input_schema from ToolDef directly
|
||||
tool_name_to_def[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||
|
|
@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[identifier] = ToolDefinition(
|
||||
tool_name=identifier,
|
||||
# Convert BuiltinTool to string for dictionary key
|
||||
identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
|
||||
if tool_name_to_def.get(identifier_str, None):
|
||||
raise ValueError(f"Tool {identifier_str} already exists")
|
||||
tool_name_to_def[identifier_str] = ToolDefinition(
|
||||
tool_name=identifier_str,
|
||||
description=tool_def.description,
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
self.tool_defs, self.tool_name_to_args = (
|
||||
list(tool_name_to_def.values()),
|
||||
|
|
@ -966,14 +995,17 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
|
||||
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**args,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
result = cast(
|
||||
ToolInvocationResult,
|
||||
await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
|
||||
**args,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
),
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
|
@ -983,7 +1015,7 @@ async def load_data_from_url(url: str) -> str:
|
|||
if url.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(url)
|
||||
resp = r.text
|
||||
resp: str = r.text
|
||||
return resp
|
||||
raise ValueError(f"Unexpected URL: {type(url)}")
|
||||
|
||||
|
|
@ -1017,7 +1049,7 @@ def _interpret_content_as_attachment(
|
|||
snippet = match.group(1)
|
||||
data = json.loads(snippet)
|
||||
return Attachment(
|
||||
url=URL(uri="file://" + data["filepath"]),
|
||||
content=URL(uri="file://" + data["filepath"]),
|
||||
mime_type=data["mimetype"],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
|||
Document,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
|
|
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
persistence_store=(
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at,
|
||||
created_at=agent_info.created_at.isoformat(),
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
|
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
documents: list[Document] | None = None,
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
|
|
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
if turn is None:
|
||||
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
|
|
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def get_agents_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
|
|
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
# Delete turns first, then the session
|
||||
|
|
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=chat_agent.agent_config,
|
||||
created_at=chat_agent.created_at,
|
||||
created_at=datetime.fromisoformat(chat_agent.created_at),
|
||||
)
|
||||
return agent
|
||||
|
||||
|
|
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.get_openai_response(response_id)
|
||||
|
||||
async def create_openai_response(
|
||||
|
|
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrail] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
result = await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
model,
|
||||
prompt,
|
||||
|
|
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
max_infer_iters,
|
||||
guardrails,
|
||||
)
|
||||
return result # type: ignore[no-any-return]
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
|
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
|
|
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.list_openai_response_input_items(
|
||||
response_id, after, before, include, limit, order
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> None:
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.delete_openai_response(response_id)
|
||||
|
|
|
|||
|
|
@ -6,12 +6,14 @@
|
|||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.access_control.conditions import User as ProtocolUser
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
|
|||
created_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResource:
|
||||
"""Concrete implementation of ProtectedResource for session access control."""
|
||||
|
||||
type: str
|
||||
identifier: str
|
||||
owner: ProtocolUser # Use the protocol type for structural compatibility
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||
self.agent_id = agent_id
|
||||
|
|
@ -53,8 +64,15 @@ class AgentPersistence:
|
|||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||
raise AccessDeniedError("create", session_info, user)
|
||||
# Only perform access control if we have an authenticated user
|
||||
if user is not None and session_info.identifier is not None:
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=user,
|
||||
)
|
||||
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
||||
raise AccessDeniedError(Action.CREATE, resource, user)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
@ -62,7 +80,7 @@ class AgentPersistence:
|
|||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
|
|
@ -83,7 +101,22 @@ class AgentPersistence:
|
|||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
# Get current user - if None, skip access control (e.g., in tests)
|
||||
user = get_authenticated_user()
|
||||
if user is None:
|
||||
return True
|
||||
|
||||
# Access control requires identifier and owner to be set
|
||||
if session_info.identifier is None or session_info.owner is None:
|
||||
return True
|
||||
|
||||
# At this point, both identifier and owner are guaranteed to be non-None
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=session_info.owner,
|
||||
)
|
||||
return is_action_allowed(self.policy, Action.READ, resource, user)
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
|
|
|
|||
|
|
@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
|
|||
input: str | list[OpenAIResponseInput],
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
||||
):
|
||||
new_input_items = previous_response.input.copy()
|
||||
# Convert Sequence to list for mutation
|
||||
new_input_items = list(previous_response.input)
|
||||
new_input_items.extend(previous_response.output)
|
||||
|
||||
if isinstance(input, str):
|
||||
|
|
@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None,
|
||||
previous_response_id: str | None,
|
||||
conversation: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
|
|
@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
|
|||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
|
||||
input_items_data: list[OpenAIResponseInput] = []
|
||||
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
|
|
@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
|
|||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
|
|
@ -251,7 +254,7 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrailSpec] | None = None,
|
||||
guardrails: list[str | ResponseGuardrailSpec] | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
|
@ -289,16 +292,19 @@ class OpenAIResponsesImpl:
|
|||
failed_response = None
|
||||
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
if final_response is not None:
|
||||
raise ValueError(
|
||||
"The response stream produced multiple terminal responses! "
|
||||
f"Earlier response from {final_event_type}"
|
||||
)
|
||||
final_response = stream_chunk.response
|
||||
final_event_type = stream_chunk.type
|
||||
elif stream_chunk.type == "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
match stream_chunk.type:
|
||||
case "response.completed" | "response.incomplete":
|
||||
if final_response is not None:
|
||||
raise ValueError(
|
||||
"The response stream produced multiple terminal responses! "
|
||||
f"Earlier response from {final_event_type}"
|
||||
)
|
||||
final_response = stream_chunk.response
|
||||
final_event_type = stream_chunk.type
|
||||
case "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
case _:
|
||||
pass # Other event types don't have .response
|
||||
|
||||
if failed_response is not None:
|
||||
error_message = (
|
||||
|
|
@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
|
|||
max_infer_iters: int | None = 10,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# These should never be None when called from create_openai_response (which sets defaults)
|
||||
# but we assert here to help mypy understand the types
|
||||
assert text is not None, "text must not be None"
|
||||
assert max_infer_iters is not None, "max_infer_iters must not be None"
|
||||
|
||||
# Input preprocessing
|
||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||
input, tools, previous_response_id, conversation
|
||||
|
|
@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
|
|||
final_response = None
|
||||
failed_response = None
|
||||
|
||||
output_items = []
|
||||
# Type as ConversationItem to avoid list invariance issues
|
||||
output_items: list[ConversationItem] = []
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
final_response = stream_chunk.response
|
||||
elif stream_chunk.type == "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
|
||||
if stream_chunk.type == "response.output_item.done":
|
||||
item = stream_chunk.item
|
||||
output_items.append(item)
|
||||
match stream_chunk.type:
|
||||
case "response.completed" | "response.incomplete":
|
||||
final_response = stream_chunk.response
|
||||
case "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
case "response.output_item.done":
|
||||
item = stream_chunk.item
|
||||
output_items.append(item)
|
||||
case _:
|
||||
pass # Other event types
|
||||
|
||||
# Store and sync before yielding terminal events
|
||||
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
|
||||
|
|
@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
|
|||
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
|
||||
) -> None:
|
||||
"""Sync content and response messages to the conversation."""
|
||||
conversation_items = []
|
||||
# Type as ConversationItem union to avoid list invariance issues
|
||||
conversation_items: list[ConversationItem] = []
|
||||
|
||||
if isinstance(input, str):
|
||||
conversation_items.append(
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
|
|||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
instructions: str,
|
||||
instructions: str | None,
|
||||
safety_api,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
prompt: OpenAIResponsePrompt | None = None,
|
||||
|
|
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
|
|||
self.prompt = prompt
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||
ctx.tool_context.previous_tools if ctx.tool_context else {}
|
||||
)
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
# mapping for annotations
|
||||
|
|
@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
|
|||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
# Pydantic models are dict-compatible but mypy treats them as distinct types
|
||||
tools=self.ctx.chat_tools, # type: ignore[arg-type]
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=response_format,
|
||||
|
|
@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
has_tool_calls = (
|
||||
isinstance(choice.message, OpenAIAssistantMessageParam)
|
||||
and choice.message.tool_calls
|
||||
and self.ctx.response_tools
|
||||
)
|
||||
if not has_tool_calls:
|
||||
output_messages.append(
|
||||
await convert_chat_choice_to_response_message(
|
||||
choice,
|
||||
|
|
@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
|
|||
)
|
||||
|
||||
# Accumulate arguments for final response (only for subsequent chunks)
|
||||
if not is_new_tool_call:
|
||||
if not is_new_tool_call and response_tool_call is not None:
|
||||
# Both should have functions since we're inside the tool_call.function check above
|
||||
assert response_tool_call.function is not None
|
||||
assert tool_call.function is not None
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
|
@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
|
|||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call = chat_response_tool_calls[tool_call_index]
|
||||
# Ensure that arguments, if sent back to the inference provider, are not None
|
||||
tool_call.function.arguments = tool_call.function.arguments or "{}"
|
||||
if tool_call.function:
|
||||
tool_call.function.arguments = tool_call.function.arguments or "{}"
|
||||
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||
final_arguments = tool_call.function.arguments
|
||||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||
final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
|
||||
func = chat_response_tool_calls[tool_call_index].function
|
||||
|
||||
tool_call_name = func.name if func else ""
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
||||
|
|
@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
self.sequence_number += 1
|
||||
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
|
||||
item = OpenAIResponseOutputMessageMCPCall(
|
||||
item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
|
||||
arguments="",
|
||||
name=tool_call.function.name,
|
||||
id=matching_item_id,
|
||||
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
|
||||
status="in_progress",
|
||||
)
|
||||
elif tool_call.function.name == "web_search":
|
||||
item = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
|
|
@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
|
|||
description=tool.description,
|
||||
input_schema=tool.input_schema,
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
# Initialize chat_tools if not already set
|
||||
if self.ctx.chat_tools is None:
|
||||
|
|
@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
for input_tool in tools:
|
||||
if input_tool.type == "function":
|
||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
# Need to access tool_groups_api from tool_executor
|
||||
|
|
@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
|
|||
if isinstance(mcp_tool.allowed_tools, list):
|
||||
always_allowed = mcp_tool.allowed_tools
|
||||
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = mcp_tool.allowed_tools.always
|
||||
never_allowed = mcp_tool.allowed_tools.never
|
||||
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
|
||||
always_allowed = mcp_tool.allowed_tools.tool_names
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = None
|
||||
|
|
@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
|
|||
openai_tool = convert_tooldef_to_chat_tool(t)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
# Add to MCP tool mapping
|
||||
if t.name in self.mcp_tool_to_server:
|
||||
|
|
@ -1120,13 +1133,17 @@ class StreamingResponseOrchestrator:
|
|||
self, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Handle all mcp tool lists from previous response that are still valid:
|
||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||
yield evt
|
||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.tool_context.tools_to_process:
|
||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
||||
yield stream_event
|
||||
# tool_context can be None when no tools are provided in the response request
|
||||
if self.ctx.tool_context:
|
||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||
yield evt
|
||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.tool_context.tools_to_process:
|
||||
async for stream_event in self._process_new_tools(
|
||||
self.ctx.tool_context.tools_to_process, output_messages
|
||||
):
|
||||
yield stream_event
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
if tool_name not in self.mcp_tool_to_server:
|
||||
|
|
@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
|
|||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
|
|
@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -67,7 +69,7 @@ class ToolExecutor:
|
|||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
|
|
@ -84,7 +86,16 @@ class ToolExecutor:
|
|||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
has_error = bool(
|
||||
error_exc
|
||||
or (
|
||||
result
|
||||
and (
|
||||
((error_code := getattr(result, "error_code", None)) and error_code > 0)
|
||||
or getattr(result, "error_message", None)
|
||||
)
|
||||
)
|
||||
)
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||
):
|
||||
|
|
@ -101,7 +112,9 @@ class ToolExecutor:
|
|||
sequence_number=sequence_number,
|
||||
final_output_message=output_message,
|
||||
final_input_message=input_message,
|
||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
||||
citation_files=(
|
||||
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
|
||||
),
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
|
|
@ -188,8 +201,9 @@ class ToolExecutor:
|
|||
|
||||
citation_files[file_id] = filename
|
||||
|
||||
# Cast to proper InterleavedContent type (list invariance)
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
content=content_items, # type: ignore[arg-type]
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
|
|
@ -209,51 +223,60 @@ class ToolExecutor:
|
|||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function_name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
# For file search, emit searching event
|
||||
if function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
|
|
@ -261,7 +284,7 @@ class ToolExecutor:
|
|||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[Exception | None, any]:
|
||||
) -> tuple[Exception | None, Any]:
|
||||
"""Execute the tool and return error exception and result."""
|
||||
error_exc = None
|
||||
result = None
|
||||
|
|
@ -284,9 +307,13 @@ class ToolExecutor:
|
|||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
None,
|
||||
response_file_search_tool = (
|
||||
next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
None,
|
||||
)
|
||||
if ctx.response_tools
|
||||
else None
|
||||
)
|
||||
if response_file_search_tool:
|
||||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
|
|
@ -322,35 +349,34 @@ class ToolExecutor:
|
|||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
|
||||
|
||||
async def _build_result_messages(
|
||||
self,
|
||||
|
|
@ -360,21 +386,18 @@ class ToolExecutor:
|
|||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
result: any,
|
||||
result: Any,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[any, any]:
|
||||
) -> tuple[Any, Any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
# Build output message
|
||||
message: Any
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=item_id,
|
||||
arguments=function.arguments,
|
||||
|
|
@ -383,10 +406,14 @@ class ToolExecutor:
|
|||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
|
||||
result and getattr(result, "error_message", None)
|
||||
):
|
||||
ec = getattr(result, "error_code", "unknown")
|
||||
em = getattr(result, "error_message", "")
|
||||
message.error = f"Error (code {ec}): {em}"
|
||||
elif result and (content := getattr(result, "content", None)):
|
||||
message.output = interleaved_content_as_str(content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
|
|
@ -401,17 +428,17 @@ class ToolExecutor:
|
|||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if result and "document_ids" in result.metadata:
|
||||
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
for i, doc_id in enumerate(metadata["document_ids"]):
|
||||
text = metadata["chunks"][i] if "chunks" in metadata else None
|
||||
score = metadata["scores"][i] if "scores" in metadata else None
|
||||
message.results.append(
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||
file_id=doc_id,
|
||||
filename=doc_id,
|
||||
text=text,
|
||||
score=score,
|
||||
text=text if text is not None else "",
|
||||
score=score if score is not None else 0.0,
|
||||
attributes={},
|
||||
)
|
||||
)
|
||||
|
|
@ -421,27 +448,32 @@ class ToolExecutor:
|
|||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
# Build input message
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
content = []
|
||||
for item in result.content:
|
||||
input_message: OpenAIToolMessageParam | None = None
|
||||
if result and (result_content := getattr(result, "content", None)):
|
||||
# all the mypy contortions here are still unsatisfactory with random Any typing
|
||||
if isinstance(result_content, str):
|
||||
msg_content: str | list[Any] = result_content
|
||||
elif isinstance(result_content, list):
|
||||
content_list: list[Any] = []
|
||||
for item in result_content:
|
||||
part: Any
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
url_value = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
url_value = str(item.image.url) if item.image.url else ""
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
content_list.append(part)
|
||||
msg_content = content_list
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
raise ValueError(f"Unknown result content type: {type(result_content)}")
|
||||
# OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
|
||||
# This is runtime-safe as the API accepts it, but mypy complains
|
||||
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||
else:
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -100,17 +101,19 @@ class ToolContext(BaseModel):
|
|||
if isinstance(tool, OpenAIResponseToolMCP):
|
||||
previous_tools_by_label[tool.server_label] = tool
|
||||
# collect tool definitions which are the same in current and previous requests:
|
||||
tools_to_process = []
|
||||
tools_to_process: list[OpenAIResponseInputTool] = []
|
||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
for tool in self.current_tools:
|
||||
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
|
||||
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
|
||||
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||
previous_tool = previous_tools_by_label[tool.server_label]
|
||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||
matched[tool.server_label] = tool
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||
# tools that are not the same or were not previously defined need to be processed:
|
||||
self.tools_to_process = tools_to_process
|
||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||
|
|
@ -119,9 +122,11 @@ class ToolContext(BaseModel):
|
|||
]
|
||||
# reconstruct the tool to server mappings that can be reused:
|
||||
for listing in self.previous_tool_listings:
|
||||
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
|
||||
definition = matched[listing.server_label]
|
||||
for tool in listing.tools:
|
||||
self.previous_tools[tool.name] = definition
|
||||
for mcp_tool in listing.tools:
|
||||
# mcp_tool is MCPListToolsTool which has a name: str field
|
||||
self.previous_tools[mcp_tool.name] = definition
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.current_tools:
|
||||
|
|
@ -139,6 +144,8 @@ class ToolContext(BaseModel):
|
|||
server_label=tool.server_label,
|
||||
allowed_tools=tool.allowed_tools,
|
||||
)
|
||||
# Exhaustive check - all tool types should be handled above
|
||||
raise AssertionError(f"Unexpected tool type: {type(tool)}")
|
||||
|
||||
return [convert_tool(tool) for tool in self.current_tools]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
|
|
@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
|
|||
|
||||
return OpenAIResponseMessage(
|
||||
id=message_id or f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
|
@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
|
|||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
# Type with union to avoid list invariance issues
|
||||
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
|
|
@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
|
|||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
# Output can be None, use empty string as fallback
|
||||
output_content = input_item.output if input_item.output is not None else ""
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
content=output_content,
|
||||
tool_call_id=input_item.id,
|
||||
)
|
||||
)
|
||||
|
|
@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
|
|||
):
|
||||
# these are handled by the responses impl itself and not pass through to chat completions
|
||||
pass
|
||||
else:
|
||||
elif isinstance(input_item, OpenAIResponseMessage):
|
||||
# Narrow type to OpenAIResponseMessage which has content and role attributes
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
|
|
@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
|
|||
last_user_content = getattr(last_user_msg, "content", None)
|
||||
if last_user_content == content:
|
||||
continue # Skip duplicate user message
|
||||
messages.append(message_type(content=content))
|
||||
# Dynamic message type call - different message types have different content expectations
|
||||
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
|
||||
if len(tool_call_results):
|
||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||
if previous_messages:
|
||||
|
|
@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
|
|||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
# Assert name exists for json_schema format
|
||||
assert text.format.get("name"), "json_schema format requires a name"
|
||||
schema_name: str = text.format["name"] # type: ignore[assignment]
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
|
@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
|
|||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
|
||||
|
||||
|
||||
def _extract_citations_from_text(
|
||||
|
|
@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
shields_list = await safety_api.routing_table.list_shields()
|
||||
# TODO: list_shields not in Safety interface but available at runtime via API routing
|
||||
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
|
|
@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
for result in response.results:
|
||||
if result.flagged:
|
||||
message = result.user_message or "Content blocked by safety guardrails"
|
||||
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
|
||||
flagged_categories = (
|
||||
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
|
||||
)
|
||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||
|
||||
if flagged_categories:
|
||||
|
|
@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
return message
|
||||
|
||||
# No violations found
|
||||
return None
|
||||
|
||||
|
||||
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
||||
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
|
|||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
|
|
|
|||
|
|
@ -28,4 +28,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.files,
|
||||
provider_type="remote::openai",
|
||||
adapter_type="openai",
|
||||
pip_packages=["openai"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.openai",
|
||||
config_class="llama_stack.providers.remote.files.openai.config.OpenAIFilesImplConfig",
|
||||
description="OpenAI Files API provider for managing files through OpenAI's native file storage service.",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
19
src/llama_stack/providers/remote/files/openai/__init__.py
Normal file
19
src/llama_stack/providers/remote/files/openai/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.core.datatypes import AccessRule, Api

from .config import OpenAIFilesImplConfig


async def get_adapter_impl(config: OpenAIFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
    from .files import OpenAIFilesImpl

    impl = OpenAIFilesImpl(config, policy or [])
    await impl.initialize()
    return impl
src/llama_stack/providers/remote/files/openai/config.py (new file, 28 lines)
@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.core.storage.datatypes import SqlStoreReference


class OpenAIFilesImplConfig(BaseModel):
    """Configuration for OpenAI Files API provider."""

    api_key: str = Field(description="OpenAI API key for authentication")
    metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")

    @classmethod
    def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
        return {
            "api_key": "${env.OPENAI_API_KEY}",
            "metadata_store": SqlStoreReference(
                backend="sql_default",
                table_name="openai_files_metadata",
            ).model_dump(exclude_none=True),
        }
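A quick way to see the rendered defaults is to call `sample_run_config` directly. A minimal sketch (the distro directory argument is unused by this particular config, so any path works; the output shape below is read off the code above, not captured from a run):

```python
# Sketch: print the sample run-config entry for the OpenAI files provider.
from llama_stack.providers.remote.files.openai.config import OpenAIFilesImplConfig

print(OpenAIFilesImplConfig.sample_run_config(__distro_dir__="/tmp/distro"))
# Expected shape, with the env reference resolved at server start:
# {'api_key': '${env.OPENAI_API_KEY}',
#  'metadata_store': {'backend': 'sql_default', 'table_name': 'openai_files_metadata'}}
```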
src/llama_stack/providers/remote/files/openai/files.py (new file, 239 lines)
@ -0,0 +1,239 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from openai import OpenAI
|
||||
|
||||
from .config import OpenAIFilesImplConfig
|
||||
|
||||
|
||||
def _make_file_object(
|
||||
*,
|
||||
id: str,
|
||||
filename: str,
|
||||
purpose: str,
|
||||
bytes: int,
|
||||
created_at: int,
|
||||
expires_at: int,
|
||||
**kwargs: Any,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Construct an OpenAIFileObject and normalize expires_at.
|
||||
|
||||
If expires_at is greater than the max we treat it as no-expiration and
|
||||
return None for expires_at.
|
||||
"""
|
||||
obj = OpenAIFileObject(
|
||||
id=id,
|
||||
filename=filename,
|
||||
purpose=OpenAIFilePurpose(purpose),
|
||||
bytes=bytes,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
|
||||
obj.expires_at = None # type: ignore
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class OpenAIFilesImpl(Files):
|
||||
"""OpenAI Files API implementation."""
|
||||
|
||||
def __init__(self, config: OpenAIFilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self._config = config
|
||||
self.policy = policy
|
||||
self._client: OpenAI | None = None
|
||||
self._sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
def _now(self) -> int:
|
||||
"""Return current UTC timestamp as int seconds."""
|
||||
return int(datetime.now(UTC).timestamp())
|
||||
|
||||
async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
|
||||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
async def _delete_file(self, file_id: str) -> None:
|
||||
"""Delete a file from OpenAI and the database."""
|
||||
try:
|
||||
self.client.files.delete(file_id)
|
||||
except Exception as e:
|
||||
# If file doesn't exist on OpenAI side, just remove from metadata store
|
||||
if "not found" not in str(e).lower():
|
||||
raise RuntimeError(f"Failed to delete file from OpenAI: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
async def _delete_if_expired(self, file_id: str) -> None:
|
||||
"""If the file exists and is expired, delete it."""
|
||||
if row := await self._get_file(file_id, return_expired=True):
|
||||
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
|
||||
await self._delete_file(file_id)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = OpenAI(api_key=self._config.api_key)
|
||||
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"filename": ColumnType.STRING,
|
||||
"purpose": ColumnType.STRING,
|
||||
"bytes": ColumnType.INTEGER,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"expires_at": ColumnType.INTEGER,
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> OpenAI:
|
||||
assert self._client is not None, "Provider not initialized"
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> AuthorizedSqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
created_at = self._now()
|
||||
|
||||
expires_at = created_at + ExpiresAfter.MAX * 42
|
||||
if purpose == OpenAIFilePurpose.BATCH:
|
||||
expires_at = created_at + ExpiresAfter.MAX
|
||||
|
||||
if expires_after is not None:
|
||||
expires_at = created_at + expires_after.seconds
|
||||
|
||||
try:
|
||||
from io import BytesIO
|
||||
|
||||
file_obj = BytesIO(content)
|
||||
file_obj.name = filename
|
||||
|
||||
response = self.client.files.create(
|
||||
file=file_obj,
|
||||
purpose=purpose.value,
|
||||
)
|
||||
|
||||
file_id = response.id
|
||||
|
||||
entry: dict[str, Any] = {
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
await self.sql_store.insert("openai_files", entry)
|
||||
|
||||
return _make_file_object(**entry)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to upload file to OpenAI: {e}") from e
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [_make_file_object(**row) for row in paginated_result.data]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=paginated_result.has_more,
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
await self._delete_if_expired(file_id)
|
||||
row = await self._get_file(file_id)
|
||||
return _make_file_object(**row)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
await self._delete_if_expired(file_id)
|
||||
_ = await self._get_file(file_id)
|
||||
await self._delete_file(file_id)
|
||||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
await self._delete_if_expired(file_id)
|
||||
|
||||
row = await self._get_file(file_id)
|
||||
|
||||
try:
|
||||
response = self.client.files.content(file_id)
|
||||
file_content = response.content
|
||||
|
||||
except Exception as e:
|
||||
if "not found" in str(e).lower():
|
||||
await self._delete_file(file_id)
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||
raise RuntimeError(f"Failed to download file from OpenAI: {e}") from e
|
||||
|
||||
return Response(
|
||||
content=file_content,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||
)
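Since this provider proxies OpenAI's file storage while keeping metadata in the local SQL store, client usage stays OpenAI-compatible. A hedged sketch of exercising the routes through the stock `openai` client pointed at a Llama Stack server (the base URL, port, and file name are assumptions, not values from this change):

```python
# Sketch, not a verified recipe: assumes a running Llama Stack server that
# exposes the OpenAI-compatible /v1/files routes on localhost:8321.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

uploaded = client.files.create(file=open("report.pdf", "rb"), purpose="assistants")
print(uploaded.id, uploaded.filename, uploaded.bytes)

listing = client.files.list()                 # served from the provider's metadata table
content = client.files.content(uploaded.id)   # streamed back from OpenAI's storage
client.files.delete(uploaded.id)              # removes the remote file and the local metadata row
```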
@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
        return "https://api.anthropic.com/v1"

    async def list_provider_model_ids(self) -> Iterable[str]:
        return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
        api_key = self._get_api_key_from_config_or_provider_data()
        return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()]

@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):

    async def list_provider_model_ids(self) -> Iterable[str]:
        # Filter out None values from endpoint names
        api_token = self._get_api_key_from_config_or_provider_data()
        return [
            endpoint.name  # type: ignore[misc]
            for endpoint in WorkspaceClient(
                host=self.config.url, token=self.get_api_key()
                host=self.config.url, token=api_token
            ).serving_endpoints.list()  # TODO: this is not async
        ]
@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(

print(f"VLM Response: {vlm_response.choices[0].message.content}")
```

### Rerank Example

The following example shows how to rerank documents using an NVIDIA NIM.

```python
rerank_response = client.alpha.inference.rerank(
    model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="query",
    items=[
        "item_1",
        "item_2",
        "item_3",
    ],
)

for i, result in enumerate(rerank_response):
    print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
```
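If the original item order matters less than the best matches, the results can also be sorted client-side. A small follow-up sketch, assuming `rerank_response` and the `items` list from the example above are still in scope:

```python
# Sort the original items by relevance, highest first.
items = ["item_1", "item_2", "item_3"]
ranked = sorted(rerank_response, key=lambda r: r.relevance_score, reverse=True)
for r in ranked:
    print(f"{items[r.index]}: {r.relevance_score:.3f}")
```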
@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
Attributes:
|
||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||
api_key (str): The access key for the hosted NIM endpoints
|
||||
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
|
||||
|
||||
There are two ways to access NVIDIA NIMs -
|
||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||
|
|
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
rerank_model_to_url: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||
},
|
||||
description="Mapping of rerank model identifiers to their API endpoints. ",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,19 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import aiohttp
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
RerankData,
|
||||
RerankResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
"""
|
||||
Return both dynamic model IDs and statically configured rerank model IDs.
|
||||
"""
|
||||
dynamic_ids: Iterable[str] = []
|
||||
try:
|
||||
dynamic_ids = await super().list_provider_model_ids()
|
||||
except Exception:
|
||||
# If the dynamic listing fails, proceed with just configured rerank IDs
|
||||
dynamic_ids = []
|
||||
|
||||
configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
|
||||
return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates
|
||||
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
|
||||
"""
|
||||
Classify rerank models from config; otherwise use the base behavior.
|
||||
"""
|
||||
if identifier in self.config.rerank_model_to_url:
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
return super().construct_model_from_identifier(identifier)
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
ranking_url = self.get_base_url()
|
||||
|
||||
if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
|
||||
ranking_url = self.config.rerank_model_to_url[provider_model_id]
|
||||
|
||||
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
|
||||
|
||||
# Convert query to text format
|
||||
if isinstance(query, str):
|
||||
query_text = query
|
||||
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
|
||||
query_text = query.text
|
||||
else:
|
||||
raise ValueError("Query must be a string or text content part")
|
||||
|
||||
# Convert items to text format
|
||||
passages = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
passages.append({"text": item})
|
||||
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
|
||||
passages.append({"text": item.text})
|
||||
else:
|
||||
raise ValueError("Items must be strings or text content parts")
|
||||
|
||||
payload = {
|
||||
"model": provider_model_id,
|
||||
"query": {"text": query_text},
|
||||
"passages": passages,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_api_key()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(ranking_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
raise ConnectionError(
|
||||
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
rankings = result.get("rankings", [])
|
||||
|
||||
# Convert to RerankData format
|
||||
rerank_data = []
|
||||
for ranking in rankings:
|
||||
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
|
||||
|
||||
# Apply max_num_results limit
|
||||
if max_num_results is not None:
|
||||
rerank_data = rerank_data[:max_num_results]
|
||||
|
||||
return RerankResponse(data=rerank_data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
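To make the wire format above concrete: the adapter posts a JSON body with `model`, `query.text`, and `passages[].text`, then reads back `rankings` entries whose `logit` becomes `relevance_score`. A hedged standalone sketch of that exchange; the endpoint URL comes from `rerank_model_to_url` in the NVIDIA config, while the query and passage texts are made up:

```python
# Illustrative sketch only: mirrors the payload/response handling in the adapter above.
import asyncio
import os

import aiohttp


async def rerank_once() -> None:
    url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking"
    payload = {
        "model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
        "query": {"text": "which passage mentions shipping times?"},
        "passages": [{"text": "Orders ship within 2 days."}, {"text": "Returns take 30 days."}],
    }
    headers = {
        "Authorization": f"Bearer {os.environ['NVIDIA_API_KEY']}",
        "Content-Type": "application/json",
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as response:
            response.raise_for_status()
            result = await response.json()
    # Expected shape: {"rankings": [{"index": 0, "logit": ...}, ...]}; "logit" maps to relevance_score.
    for ranking in result.get("rankings", []):
        print(ranking["index"], ranking["logit"])


asyncio.run(rerank_once())
```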
@ -35,6 +35,7 @@ class InferenceStore:
|
|||
self.reference = reference
|
||||
self.sql_store = None
|
||||
self.policy = policy
|
||||
self.enable_write_queue = True
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
|
||||
|
|
@ -47,14 +48,13 @@ class InferenceStore:
|
|||
base_store = sqlstore_impl(self.reference)
|
||||
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
backend_name = self.reference.backend
|
||||
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
|
||||
# Disable write queue for SQLite since WAL mode handles concurrency
|
||||
# Keep it enabled for other backends (like Postgres) for performance
|
||||
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
|
||||
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
|
||||
self.enable_write_queue = False
|
||||
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
|
|
@ -70,8 +70,9 @@ class InferenceStore:
|
|||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
logger.debug(
|
||||
f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
|
|
|
|||
|
|
@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
|
|||
return schema
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
from typing import Any
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
|
||||
|
|
@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin(
|
|||
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
fmt = fmt.json_schema
|
||||
name = fmt["title"]
|
||||
del fmt["title"]
|
||||
fmt["additionalProperties"] = False
|
||||
# Convert to dict for manipulation
|
||||
fmt_dict = dict(fmt.json_schema)
|
||||
name = fmt_dict["title"]
|
||||
del fmt_dict["title"]
|
||||
fmt_dict["additionalProperties"] = False
|
||||
|
||||
# Apply additionalProperties: False recursively to all objects
|
||||
fmt = self._add_additional_properties_recursive(fmt)
|
||||
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
|
||||
|
||||
input_dict["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": name,
|
||||
"schema": fmt,
|
||||
"schema": fmt_dict,
|
||||
"strict": self.json_schema_strict,
|
||||
},
|
||||
}
|
||||
if request.tools:
|
||||
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
||||
if request.tool_config.tool_choice:
|
||||
input_dict["tool_choice"] = (
|
||||
request.tool_config.tool_choice.value
|
||||
if isinstance(request.tool_config.tool_choice, ToolChoice)
|
||||
else request.tool_config.tool_choice
|
||||
)
|
||||
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
|
||||
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
|
|
@ -176,10 +175,10 @@ class LiteLLMOpenAIMixin(
|
|||
def get_api_key(self) -> str:
|
||||
provider_data = self.get_request_provider_data()
|
||||
key_field = self.provider_data_api_key_field
|
||||
if provider_data and getattr(provider_data, key_field, None):
|
||||
api_key = getattr(provider_data, key_field)
|
||||
else:
|
||||
api_key = self.api_key_from_config
|
||||
if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
|
||||
return str(api_key) # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection
|
||||
|
||||
api_key = self.api_key_from_config
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"API key is not set. Please provide a valid API key in the "
|
||||
|
|
@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
# Convert input to list if it's a string
|
||||
input_list = [params.input] if isinstance(params.input, str) else params.input
|
||||
|
|
@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin(
|
|||
# Call litellm embedding function
|
||||
# litellm.drop_params = True
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
input=input_list,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
|
|
@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin(
|
|||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
model=provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
|
@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
prompt=params.prompt,
|
||||
best_of=params.best_of,
|
||||
echo=params.echo,
|
||||
|
|
@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin(
|
|||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.atext_completion(**request_params)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin(
|
|||
elif "include_usage" not in stream_options:
|
||||
stream_options = {**stream_options, "include_usage": True}
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
messages=params.messages,
|
||||
frequency_penalty=params.frequency_penalty,
|
||||
function_call=params.function_call,
|
||||
|
|
@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin(
|
|||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.acompletion(**request_params)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
|
|||
|
||||
|
||||
class RemoteInferenceProviderConfig(BaseModel):
|
||||
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
|
||||
allowed_models: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -161,8 +161,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
|
|||
if isinstance(params.strategy, GreedySamplingStrategy):
|
||||
options["temperature"] = 0.0
|
||||
elif isinstance(params.strategy, TopPSamplingStrategy):
|
||||
options["temperature"] = params.strategy.temperature
|
||||
options["top_p"] = params.strategy.top_p
|
||||
if params.strategy.temperature is not None:
|
||||
options["temperature"] = params.strategy.temperature
|
||||
if params.strategy.top_p is not None:
|
||||
options["top_p"] = params.strategy.top_p
|
||||
elif isinstance(params.strategy, TopKSamplingStrategy):
|
||||
options["top_k"] = params.strategy.top_k
|
||||
else:
|
||||
|
|
@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict:
|
|||
|
||||
def text_from_choice(choice) -> str:
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
return choice.delta.content
|
||||
return choice.delta.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
|
||||
|
||||
if hasattr(choice, "message"):
|
||||
return choice.message.content
|
||||
return choice.message.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
|
||||
|
||||
return choice.text
|
||||
return choice.text # type: ignore[no-any-return] # external OpenAI types lack precise annotations
|
||||
|
||||
|
||||
def get_stop_reason(finish_reason: str) -> StopReason:
|
||||
|
|
@ -216,7 +218,7 @@ def convert_openai_completion_logprobs(
|
|||
) -> list[TokenLogProbs] | None:
|
||||
if not logprobs:
|
||||
return None
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
|
||||
# Together supports logprobs with top_k=1 only. This means for each token position,
|
||||
|
|
@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
|
|||
if isinstance(logprobs, float):
|
||||
# Adapt response from Together CompletionChoicesChunk
|
||||
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
return None
|
||||
|
||||
|
|
@ -245,23 +247,24 @@ def process_completion_response(
|
|||
response: OpenAICompatCompletionResponse,
|
||||
) -> CompletionResponse:
|
||||
choice = response.choices[0]
|
||||
text = choice.text or ""
|
||||
# drop suffix <eot_id> if present and return stop reason as end of turn
|
||||
if choice.text.endswith("<|eot_id|>"):
|
||||
if text.endswith("<|eot_id|>"):
|
||||
return CompletionResponse(
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
content=choice.text[: -len("<|eot_id|>")],
|
||||
content=text[: -len("<|eot_id|>")],
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
# drop suffix <eom_id> if present and return stop reason as end of message
|
||||
if choice.text.endswith("<|eom_id|>"):
|
||||
if text.endswith("<|eom_id|>"):
|
||||
return CompletionResponse(
|
||||
stop_reason=StopReason.end_of_message,
|
||||
content=choice.text[: -len("<|eom_id|>")],
|
||||
content=text[: -len("<|eom_id|>")],
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
return CompletionResponse(
|
||||
stop_reason=get_stop_reason(choice.finish_reason),
|
||||
content=choice.text,
|
||||
stop_reason=get_stop_reason(choice.finish_reason or "stop"),
|
||||
content=text,
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
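A quick worked example of the suffix handling above, with an illustrative completion string:

```python
# Illustrative only: how the suffix stripping plays out for a raw completion.
raw = "The capital of France is Paris.<|eot_id|>"
content = raw[: -len("<|eot_id|>")] if raw.endswith("<|eot_id|>") else raw
# content == "The capital of France is Paris."
# stop_reason becomes StopReason.end_of_turn; a trailing "<|eom_id|>" maps to
# end_of_message, and anything else falls back to get_stop_reason(finish_reason or "stop").
```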
@ -272,10 +275,10 @@ def process_chat_completion_response(
|
|||
) -> ChatCompletionResponse:
|
||||
choice = response.choices[0]
|
||||
if choice.finish_reason == "tool_calls":
|
||||
if not choice.message or not choice.message.tool_calls:
|
||||
if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
|
||||
raise ValueError("Tool calls are not present in the response")
|
||||
|
||||
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
|
||||
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
|
||||
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
|
||||
# If we couldn't parse a tool call, jsonify the tool calls and return them
|
||||
return ChatCompletionResponse(
|
||||
|
|
@ -287,9 +290,11 @@ def process_chat_completion_response(
|
|||
)
|
||||
else:
|
||||
# Otherwise, return tool calls as normal
|
||||
# Filter to only valid ToolCall objects
|
||||
valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
tool_calls=valid_tool_calls,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
# Content is not optional
|
||||
content="",
|
||||
|
|
@ -299,7 +304,7 @@ def process_chat_completion_response(
|
|||
|
||||
# TODO: This does not work well with tool calls for vLLM remote provider
|
||||
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop"))
|
||||
|
||||
# NOTE: If we do not set tools in chat-completion request, we should not
|
||||
# expect the ToolCall in the response. Instead, we should return the raw
|
||||
|
|
@ -324,8 +329,8 @@ def process_chat_completion_response(
|
|||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent]
|
||||
stop_reason=raw_message.stop_reason or StopReason.end_of_turn,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=None,
|
||||
|
|
@ -448,7 +453,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = decode_assistant_message(buffer, stop_reason)
|
||||
message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
|
|
@ -463,7 +468,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
)
|
||||
|
||||
request_tools = {t.tool_name: t for t in request.tools}
|
||||
request_tools = {t.tool_name: t for t in (request.tools or [])}
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in request_tools:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
|
|
@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
}
|
||||
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
result["tool_calls"] = []
|
||||
tool_calls_list = []
|
||||
for tc in message.tool_calls:
|
||||
# The tool.tool_name can be a str or a BuiltinTool enum. If
|
||||
# it's the latter, convert to a string.
|
||||
|
|
@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
||||
result["tool_calls"].append(
|
||||
tool_calls_list.append(
|
||||
{
|
||||
"id": tc.call_id,
|
||||
"type": "function",
|
||||
|
|
@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
},
|
||||
}
|
||||
)
|
||||
result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new(
|
|||
),
|
||||
)
|
||||
elif isinstance(content_, list):
|
||||
return [await impl(item) for item in content_]
|
||||
return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content_)}")
|
||||
|
||||
|
|
@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new(
|
|||
else:
|
||||
return [ret]
|
||||
|
||||
out: OpenAIChatCompletionMessage = None
|
||||
out: OpenAIChatCompletionMessage
|
||||
if isinstance(message, UserMessage):
|
||||
out = OpenAIChatCompletionUserMessage(
|
||||
role="user",
|
||||
|
|
@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new(
|
|||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
for tool in (message.tool_calls or [])
|
||||
]
|
||||
params = {}
|
||||
if tool_calls:
|
||||
|
|
@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new(
|
|||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=await _convert_message_content(message.content),
|
||||
**params,
|
||||
**params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=message.call_id,
|
||||
content=await _convert_message_content(message.content),
|
||||
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
out = OpenAIChatCompletionSystemMessage(
|
||||
role="system",
|
||||
content=await _convert_message_content(message.content),
|
||||
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
|
@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
function = out["function"]
|
||||
|
||||
if isinstance(tool.tool_name, BuiltinTool):
|
||||
function["name"] = tool.tool_name.value
|
||||
function["name"] = tool.tool_name.value # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
else:
|
||||
function["name"] = tool.tool_name
|
||||
function["name"] = tool.tool_name # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
if tool.description:
|
||||
function["description"] = tool.description
|
||||
function["description"] = tool.description # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
if tool.input_schema:
|
||||
# Pass through the entire JSON Schema as-is
|
||||
function["parameters"] = tool.input_schema
|
||||
function["parameters"] = tool.input_schema # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
# NOTE: OpenAI does not support output_schema, so we drop it here
|
||||
# It's stored in LlamaStack for validation and other provider usage
|
||||
|
|
@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None
|
|||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
try:
|
||||
tool_choice = ToolChoice(tool_choice)
|
||||
tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception
|
||||
except ValueError:
|
||||
pass
|
||||
tool_config.tool_choice = tool_choice
|
||||
tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type
|
||||
return tool_config
|
||||
|
||||
|
||||
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
|
||||
lls_tools = []
|
||||
lls_tools: list[ToolDefinition] = []
|
||||
if not tools:
|
||||
return lls_tools
|
||||
|
||||
|
|
@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
|
|||
|
||||
|
||||
def _convert_openai_request_response_format(
|
||||
response_format: OpenAIResponseFormatParam = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
):
|
||||
if not response_format:
|
||||
return None
|
||||
# response_format can be a dict or a pydantic model
|
||||
response_format = dict(response_format)
|
||||
if response_format.get("type", "") == "json_schema":
|
||||
response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
|
||||
if response_format_dict.get("type", "") == "json_schema":
|
||||
return JsonSchemaResponseFormat(
|
||||
type="json_schema",
|
||||
json_schema=response_format.get("json_schema", {}).get("schema", ""),
|
||||
type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
|
||||
json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -938,16 +944,15 @@ def _convert_openai_sampling_params(
|
|||
|
||||
# Map an explicit temperature of 0 to greedy sampling
|
||||
if temperature == 0:
|
||||
strategy = GreedySamplingStrategy()
|
||||
sampling_params.strategy = GreedySamplingStrategy()
|
||||
else:
|
||||
# OpenAI defaults to 1.0 for temperature and top_p if unset
|
||||
if temperature is None:
|
||||
temperature = 1.0
|
||||
if top_p is None:
|
||||
top_p = 1.0
|
||||
strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
|
||||
sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) # type: ignore[assignment] # SamplingParams.strategy union accepts this type
|
||||
|
||||
sampling_params.strategy = strategy
|
||||
return sampling_params
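As a worked example of the mapping above (a sketch; the keyword argument names are assumed from the call sites earlier in this file):

```python
# Expected strategy selection, following the rules in _convert_openai_sampling_params above.
params_greedy = _convert_openai_sampling_params(max_tokens=64, temperature=0, top_p=None)
# -> strategy is GreedySamplingStrategy()

params_default = _convert_openai_sampling_params(max_tokens=64, temperature=None, top_p=None)
# -> TopPSamplingStrategy(temperature=1.0, top_p=1.0), the OpenAI defaults for unset values

params_topp = _convert_openai_sampling_params(max_tokens=64, temperature=0.2, top_p=0.9)
# -> TopPSamplingStrategy(temperature=0.2, top_p=0.9)
```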
@ -957,23 +962,24 @@ def openai_messages_to_messages(
|
|||
"""
|
||||
Convert a list of OpenAIChatCompletionMessage into a list of Message.
|
||||
"""
|
||||
converted_messages = []
|
||||
converted_messages: list[Message] = []
|
||||
for message in messages:
|
||||
converted_message: Message
|
||||
if message.role == "system":
|
||||
converted_message = SystemMessage(content=openai_content_to_content(message.content))
|
||||
converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
elif message.role == "user":
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content))
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
elif message.role == "assistant":
|
||||
converted_message = CompletionMessage(
|
||||
content=openai_content_to_content(message.content),
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls),
|
||||
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
elif message.role == "tool":
|
||||
converted_message = ToolResponseMessage(
|
||||
role="tool",
|
||||
call_id=message.tool_call_id,
|
||||
content=openai_content_to_content(message.content),
|
||||
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown role {message.role}")
|
||||
|
|
@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten
|
|||
return [openai_content_to_content(c) for c in content]
|
||||
elif hasattr(content, "type"):
|
||||
if content.type == "text":
|
||||
return TextContentItem(type="text", text=content.text)
|
||||
return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
|
||||
elif content.type == "image_url":
|
||||
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
|
||||
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {content.type}")
|
||||
else:
|
||||
|
|
@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice(
|
|||
completion_message=CompletionMessage(
|
||||
content=choice.message.content or "", # CompletionMessage content is not optional
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
|
||||
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),
|
||||
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream(
|
|||
choice = chunk.choices[0] # assuming only one choice per chunk
|
||||
|
||||
# we assume there's only one finish_reason in the stream
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason
|
||||
logprobs = getattr(choice, "logprobs", None)
|
||||
|
||||
# if there's a tool call, emit an event for each tool in the list
|
||||
|
|
@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=_convert_openai_tool_calls([tool_call])[0],
|
||||
tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -1125,12 +1131,15 @@ async def convert_openai_chat_completion_stream(
|
|||
if tool_call.function.name:
|
||||
buffer["name"] = tool_call.function.name
|
||||
delta = f"{buffer['name']}("
|
||||
buffer["content"] += delta
|
||||
if buffer["content"] is not None:
|
||||
buffer["content"] += delta
|
||||
|
||||
if tool_call.function.arguments:
|
||||
delta = tool_call.function.arguments
|
||||
buffer["arguments"] += delta
|
||||
buffer["content"] += delta
|
||||
if buffer["arguments"] is not None and delta:
|
||||
buffer["arguments"] += delta
|
||||
if buffer["content"] is not None and delta:
|
||||
buffer["content"] += delta
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
|
@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream(
|
|||
tool_call=delta,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
elif choice.delta.content:
|
||||
|
|
@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1155,7 +1164,8 @@ async def convert_openai_chat_completion_stream(
|
|||
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
|
||||
if buffer["name"]:
|
||||
delta = ")"
|
||||
buffer["content"] += delta
|
||||
if buffer["content"] is not None:
|
||||
buffer["content"] += delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
|
|
@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream(
|
|||
)
|
||||
|
||||
try:
|
||||
tool_call = ToolCall(
|
||||
call_id=buffer["call_id"],
|
||||
tool_name=buffer["name"],
|
||||
arguments=buffer["arguments"],
|
||||
parsed_tool_call = ToolCall(
|
||||
call_id=buffer["call_id"] or "",
|
||||
tool_name=buffer["name"] or "",
|
||||
arguments=buffer["arguments"] or "",
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=buffer["content"],
|
||||
tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
messages = openai_messages_to_messages(messages)
|
||||
messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format
|
||||
response_format = _convert_openai_request_response_format(response_format)
|
||||
sampling_params = _convert_openai_sampling_params(
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -1259,15 +1269,15 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
)
|
||||
tool_config = _convert_openai_request_tool_config(tool_choice)
|
||||
|
||||
tools = _convert_openai_request_tools(tools)
|
||||
tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type
|
||||
|
||||
outstanding_responses = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
n = n or 1
|
||||
for _i in range(0, n):
|
||||
response = self.chat_completion(
|
||||
response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion
|
||||
model_id=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
|
|
@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
outstanding_responses.append(response)
|
||||
|
||||
if stream:
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy
|
||||
|
||||
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
|
||||
self, model, outstanding_responses
|
||||
|
|
@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
response = await outstanding_response
|
||||
async for chunk in response:
|
||||
event = chunk.event
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
|
||||
finish_reason = (
|
||||
_convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None
|
||||
)
|
||||
|
||||
if isinstance(event.delta, TextDelta):
|
||||
text_delta = event.delta.text
|
||||
delta = OpenAIChoiceDelta(content=text_delta)
|
||||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
|
||||
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
|
|
@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
elif isinstance(event.delta, ToolCallDelta):
|
||||
if event.delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
tool_call = event.delta.tool_call
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
|
||||
# First chunk includes full structure
|
||||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id=tool_call.call_id,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name=tool_call.tool_name,
|
||||
name=tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy
|
||||
arguments="",
|
||||
),
|
||||
)
|
||||
|
|
@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
|
||||
],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
|
|
@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
|
||||
],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
|
|
@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
async def _process_non_stream_response(
|
||||
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
|
||||
) -> OpenAIChatCompletion:
|
||||
choices = []
|
||||
choices: list[OpenAIChatCompletionChoice] = []
|
||||
for outstanding_response in outstanding_responses:
|
||||
response = await outstanding_response
|
||||
completion_message = response.completion_message
|
||||
|
|
@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
|
||||
choice = OpenAIChatCompletionChoice(
|
||||
index=len(choices),
|
||||
message=message,
|
||||
message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
choices.append(choice)
|
||||
choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch
|
||||
|
||||
return OpenAIChatCompletion(
|
||||
id=f"chatcmpl-{uuid.uuid4()}",
|
||||
choices=choices,
|
||||
choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
|
|
|
|||
|
|
@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
# This is set in list_models() and used in check_model_availability()
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
|
|
@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
for provider_model_id in provider_models_ids:
|
||||
if not isinstance(provider_model_id, str):
|
||||
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
|
||||
if self.allowed_models and provider_model_id not in self.allowed_models:
|
||||
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
|
||||
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
|
||||
continue
|
||||
model = self.construct_model_from_identifier(provider_model_id)
|
||||
|
|
|
|||
|
|
@ -196,6 +196,7 @@ def make_overlapped_chunks(
|
|||
chunks.append(
|
||||
Chunk(
|
||||
content=chunk,
|
||||
chunk_id=chunk_id,
|
||||
metadata=chunk_metadata,
|
||||
chunk_metadata=backend_chunk_metadata,
|
||||
)
@ -70,13 +70,13 @@ class ResponsesStore:
|
|||
base_store = sqlstore_impl(self.reference)
|
||||
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
||||
|
||||
# Disable write queue for SQLite since WAL mode handles concurrency
|
||||
# Keep it enabled for other backends (like Postgres) for performance
|
||||
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
if backend_config.type == StorageBackendType.SQL_SQLITE:
|
||||
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
|
||||
self.enable_write_queue = False
|
||||
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"openai_responses",
|
||||
{
@ -99,8 +99,9 @@ class ResponsesStore:
|
|||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
logger.debug(
|
||||
f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
@ -17,6 +17,7 @@ from sqlalchemy import (
|
|||
String,
|
||||
Table,
|
||||
Text,
|
||||
event,
|
||||
inspect,
|
||||
select,
|
||||
text,
@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
self.metadata = MetaData()
|
||||
|
||||
def create_engine(self) -> AsyncEngine:
|
||||
return create_async_engine(self.config.engine_str, pool_pre_ping=True)
|
||||
# Configure connection args for better concurrency support
|
||||
connect_args = {}
|
||||
if "sqlite" in self.config.engine_str:
|
||||
# SQLite-specific optimizations for concurrent access
|
||||
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
|
||||
connect_args["timeout"] = 5.0
|
||||
connect_args["check_same_thread"] = False # Allow usage across asyncio tasks
|
||||
|
||||
engine = create_async_engine(
|
||||
self.config.engine_str,
|
||||
pool_pre_ping=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
|
||||
# Enable WAL mode for SQLite to support concurrent readers and writers
|
||||
if "sqlite" in self.config.engine_str:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
cursor = dbapi_conn.cursor()
|
||||
# Enable Write-Ahead Logging for better concurrency
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
# Set busy timeout to 5 seconds (retry instead of immediate failure)
|
||||
# With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
|
||||
cursor.execute("PRAGMA busy_timeout=5000")
|
||||
# Use NORMAL synchronous mode for better performance (still safe with WAL)
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
async def create_table(
|
||||
self,
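A quick way to sanity-check the pragmas set by the create_engine change above; a minimal standalone sketch, assuming aiosqlite is installed and using a throwaway SQLite file path (both assumptions, not part of the commit):

import asyncio

from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import create_async_engine

# Throwaway database URL (assumption); any writable path works for this check.
engine = create_async_engine("sqlite+aiosqlite:///./wal_check.db", pool_pre_ping=True)


@event.listens_for(engine.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_conn, connection_record):
    # Same pragmas as the change above: WAL journal, 5s busy timeout, NORMAL sync.
    cursor = dbapi_conn.cursor()
    cursor.execute("PRAGMA journal_mode=WAL")
    cursor.execute("PRAGMA busy_timeout=5000")
    cursor.execute("PRAGMA synchronous=NORMAL")
    cursor.close()


async def main() -> None:
    async with engine.connect() as conn:
        # Expect "wal" and 5000 once the connect listener has fired.
        journal_mode = (await conn.execute(text("PRAGMA journal_mode"))).scalar()
        busy_timeout = (await conn.execute(text("PRAGMA busy_timeout"))).scalar()
        print(journal_mode, busy_timeout)
    await engine.dispose()


asyncio.run(main())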
@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
|
|||
return list_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_sequence(typ: object) -> bool:
|
||||
"True if the specified type is a generic Sequence, i.e. `Sequence[T]`."
|
||||
import collections.abc
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
return typing.get_origin(typ) is collections.abc.Sequence
|
||||
|
||||
|
||||
def unwrap_generic_sequence(typ: object) -> type:
|
||||
"""
|
||||
Extracts the item type of a Sequence type.
|
||||
|
||||
:param typ: The Sequence type `Sequence[T]`.
|
||||
:returns: The item type `T`.
|
||||
"""
|
||||
|
||||
return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _unwrap_generic_sequence(typ: object) -> type:
|
||||
"Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)."
|
||||
|
||||
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return sequence_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
|
||||
"True if the specified type is a generic set, i.e. `Set[T]`."
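For context on what is_generic_sequence and unwrap_generic_sequence above are inspecting, a minimal illustration using only the standard library; this sketches the underlying typing behavior rather than calling the project helpers themselves:

import collections.abc
import typing
from collections.abc import Sequence

# Sequence[T] reports collections.abc.Sequence as its origin, with T as the lone argument;
# that is exactly the check the new helpers perform before unwrapping the item type.
assert typing.get_origin(Sequence[int]) is collections.abc.Sequence
assert typing.get_args(Sequence[int]) == (int,)

# list[T] has a different origin, which is why Sequence needs its own branch.
assert typing.get_origin(list[int]) is list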
@ -18,10 +18,12 @@ from .inspection import (
|
|||
TypeLike,
|
||||
is_generic_dict,
|
||||
is_generic_list,
|
||||
is_generic_sequence,
|
||||
is_type_optional,
|
||||
is_type_union,
|
||||
unwrap_generic_dict,
|
||||
unwrap_generic_list,
|
||||
unwrap_generic_sequence,
|
||||
unwrap_optional_type,
|
||||
unwrap_union_types,
|
||||
)
|
||||
|
|
@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
|
|||
if metadata is not None:
|
||||
# type is Annotated[T, ...]
|
||||
arg = typing.get_args(data_type)[0]
|
||||
return python_type_to_name(arg)
|
||||
return python_type_to_name(arg, force=force)
|
||||
|
||||
if force:
|
||||
# generic types
|
||||
if is_type_optional(data_type, strict=True):
|
||||
inner_name = python_type_to_name(unwrap_optional_type(data_type))
|
||||
inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True)
|
||||
return f"Optional__{inner_name}"
|
||||
elif is_generic_list(data_type):
|
||||
item_name = python_type_to_name(unwrap_generic_list(data_type))
|
||||
item_name = python_type_to_name(unwrap_generic_list(data_type), force=True)
|
||||
return f"List__{item_name}"
|
||||
elif is_generic_sequence(data_type):
|
||||
# Treat Sequence the same as List for schema generation purposes
|
||||
item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True)
|
||||
return f"List__{item_name}"
|
||||
elif is_generic_dict(data_type):
|
||||
key_type, value_type = unwrap_generic_dict(data_type)
|
||||
key_name = python_type_to_name(key_type)
|
||||
value_name = python_type_to_name(value_type)
|
||||
key_name = python_type_to_name(key_type, force=True)
|
||||
value_name = python_type_to_name(value_type, force=True)
|
||||
return f"Dict__{key_name}__{value_name}"
|
||||
elif is_type_union(data_type):
|
||||
member_types = unwrap_union_types(data_type)
|
||||
member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
|
||||
member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types)
|
||||
return f"Union__{member_names}"
|
||||
|
||||
# named system or user-defined type
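The effect of threading force=True through the recursive calls above is that nested generics also receive synthetic names. A rough sketch of the expected naming, with the import path assumed rather than confirmed (adjust to wherever this project exposes python_type_to_name):

# Hypothetical import path; treat this whole snippet as illustrative only.
from llama_stack.strong_typing.name import python_type_to_name

# Nested generics now compose into synthetic names because force=True propagates inward.
assert python_type_to_name(list[str], force=True) == "List__str"
assert python_type_to_name(dict[str, int], force=True) == "Dict__str__int"
assert python_type_to_name(list[list[str]], force=True) == "List__List__str"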
@ -111,7 +111,7 @@ def get_class_property_docstrings(
|
|||
def docstring_to_schema(data_type: type) -> Schema:
|
||||
short_description, long_description = get_class_docstrings(data_type)
|
||||
schema: Schema = {
|
||||
"title": python_type_to_name(data_type),
|
||||
"title": python_type_to_name(data_type, force=True),
|
||||
}
|
||||
|
||||
description = "\n".join(filter(None, [short_description, long_description]))
|
||||
|
|
@ -417,6 +417,10 @@ class JsonSchemaGenerator:
|
|||
if origin_type is list:
|
||||
(list_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {"type": "array", "items": self.type_to_schema(list_type)}
|
||||
elif origin_type is collections.abc.Sequence:
|
||||
# Treat Sequence the same as list for JSON schema (both are arrays)
|
||||
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {"type": "array", "items": self.type_to_schema(sequence_type)}
|
||||
elif origin_type is dict:
|
||||
key_type, value_type = typing.get_args(typ)
|
||||
if not (key_type is str or key_type is int or is_type_enum(key_type)):
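The practical outcome of the Sequence branch above is that a Sequence-typed field serializes to the same array schema a list does. A minimal sketch that uses Pydantic only to show the expected shape; the JsonSchemaGenerator itself is not exercised here:

from collections.abc import Sequence

from pydantic import BaseModel


class Example(BaseModel):
    # Mirrors the list -> Sequence switch in the response models earlier in this commit.
    queries: Sequence[str]


# Pydantic still renders the field as a JSON array, so the wire format is unchanged;
# expect something like {"type": "array", "items": {"type": "string"}, ...}.
print(Example.model_json_schema()["properties"]["queries"])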
@ -51,10 +51,14 @@ async function proxyRequest(request: NextRequest, method: string) {
|
|||
);
|
||||
|
||||
// Create response with same status and headers
|
||||
const proxyResponse = new NextResponse(responseText, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
// Handle 204 No Content responses specially
|
||||
const proxyResponse =
|
||||
response.status === 204
|
||||
? new NextResponse(null, { status: 204 })
|
||||
: new NextResponse(responseText, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
|
||||
// Copy response headers (except problematic ones)
|
||||
response.headers.forEach((value, key) => {
src/llama_stack/ui/app/prompts/page.tsx (new file, 5 lines)
@ -0,0 +1,5 @@
|
|||
import { PromptManagement } from "@/components/prompts";
|
||||
|
||||
export default function PromptsPage() {
|
||||
return <PromptManagement />;
|
||||
}
@ -8,6 +8,7 @@ import {
|
|||
MessageCircle,
|
||||
Settings2,
|
||||
Compass,
|
||||
FileText,
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { usePathname } from "next/navigation";
|
||||
|
|
@ -50,6 +51,11 @@ const manageItems = [
|
|||
url: "/logs/vector-stores",
|
||||
icon: Database,
|
||||
},
|
||||
{
|
||||
title: "Prompts",
|
||||
url: "/prompts",
|
||||
icon: FileText,
|
||||
},
|
||||
{
|
||||
title: "Documentation",
|
||||
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
|
||||
|
|
|
|||
src/llama_stack/ui/components/prompts/index.ts (new file, 4 lines)
@ -0,0 +1,4 @@
|
|||
export { PromptManagement } from "./prompt-management";
|
||||
export { PromptList } from "./prompt-list";
|
||||
export { PromptEditor } from "./prompt-editor";
|
||||
export * from "./types";
|
||||
src/llama_stack/ui/components/prompts/prompt-editor.test.tsx (new file, 309 lines)
@ -0,0 +1,309 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptEditor } from "./prompt-editor";
|
||||
import type { Prompt, PromptFormData } from "./types";
|
||||
|
||||
describe("PromptEditor", () => {
|
||||
const mockOnSave = jest.fn();
|
||||
const mockOnCancel = jest.fn();
|
||||
const mockOnDelete = jest.fn();
|
||||
|
||||
const defaultProps = {
|
||||
onSave: mockOnSave,
|
||||
onCancel: mockOnCancel,
|
||||
onDelete: mockOnDelete,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Create Mode", () => {
|
||||
test("renders create form correctly", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument();
|
||||
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||
expect(screen.getByText("Preview")).toBeInTheDocument();
|
||||
expect(screen.getByText("Create Prompt")).toBeInTheDocument();
|
||||
expect(screen.getByText("Cancel")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows preview placeholder when no content", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(
|
||||
screen.getByText("Enter content to preview the compiled prompt")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("submits form with correct data", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}, welcome!" },
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith({
|
||||
prompt: "Hello {{name}}, welcome!",
|
||||
variables: [],
|
||||
});
|
||||
});
|
||||
|
||||
test("prevents submission with empty prompt", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
expect(mockOnSave).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edit Mode", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how is {{weather}}?",
|
||||
version: 1,
|
||||
variables: ["name", "weather"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
test("renders edit form with existing data", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(
|
||||
screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview
|
||||
expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview
|
||||
expect(screen.getByText("Update Prompt")).toBeInTheDocument();
|
||||
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("submits updated data correctly", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Updated: Hello {{name}}!" },
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Update Prompt"));
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith({
|
||||
prompt: "Updated: Hello {{name}}!",
|
||||
variables: ["name", "weather"],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Variables Management", () => {
|
||||
test("adds new variable", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "testVar" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
expect(screen.getByText("testVar")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("prevents adding duplicate variables", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
|
||||
// Add first variable
|
||||
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
// Try to add same variable again
|
||||
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||
|
||||
// Button should be disabled
|
||||
expect(screen.getByText("Add")).toBeDisabled();
|
||||
});
|
||||
|
||||
test("removes variable", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name", "location"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
// Check that both variables are present initially
|
||||
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||
expect(screen.getAllByText("location").length).toBeGreaterThan(0);
|
||||
|
||||
// Remove the location variable by clicking the X button with the specific title
|
||||
const removeLocationButton = screen.getByTitle(
|
||||
"Remove location variable"
|
||||
);
|
||||
fireEvent.click(removeLocationButton);
|
||||
|
||||
// Name should still be there, location should be gone from the variables section
|
||||
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||
expect(
|
||||
screen.queryByTitle("Remove location variable")
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("adds variable on Enter key", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "enterVar" } });
|
||||
|
||||
// Simulate Enter key press
|
||||
fireEvent.keyPress(variableInput, {
|
||||
key: "Enter",
|
||||
code: "Enter",
|
||||
charCode: 13,
|
||||
preventDefault: jest.fn(),
|
||||
});
|
||||
|
||||
// Check if the variable was added by looking for the badge
|
||||
expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Preview Functionality", () => {
|
||||
test("shows live preview with variables", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
// Add prompt content
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}, welcome to {{place}}!" },
|
||||
});
|
||||
|
||||
// Add variables
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "name" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
fireEvent.change(variableInput, { target: { value: "place" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
// Check that preview area shows the content
|
||||
expect(screen.getByText("Compiled Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows variable value inputs in preview", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(screen.getByText("Variable Values")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByPlaceholderText("Enter value for name")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows color legend for variable states", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
// Add content to show preview
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}" },
|
||||
});
|
||||
|
||||
expect(screen.getByText("Used")).toBeInTheDocument();
|
||||
expect(screen.getByText("Unused")).toBeInTheDocument();
|
||||
expect(screen.getByText("Undefined")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Handling", () => {
|
||||
test("displays error message", () => {
|
||||
const errorMessage = "Prompt contains undeclared variables";
|
||||
render(<PromptEditor {...defaultProps} error={errorMessage} />);
|
||||
|
||||
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Delete Functionality", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
test("shows delete button in edit mode", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("hides delete button in create mode", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("calls onDelete with confirmation", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||
|
||||
expect(window.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||
);
|
||||
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
|
||||
test("does not delete when confirmation is cancelled", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => false);
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||
|
||||
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cancel Functionality", () => {
|
||||
test("calls onCancel when cancel button is clicked", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Cancel"));
|
||||
|
||||
expect(mockOnCancel).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
src/llama_stack/ui/components/prompts/prompt-editor.tsx (new file, 346 lines)
@ -0,0 +1,346 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { X, Plus, Save, Trash2 } from "lucide-react";
|
||||
import { Prompt, PromptFormData } from "./types";
|
||||
|
||||
interface PromptEditorProps {
|
||||
prompt?: Prompt;
|
||||
onSave: (prompt: PromptFormData) => void;
|
||||
onCancel: () => void;
|
||||
onDelete?: (promptId: string) => void;
|
||||
error?: string | null;
|
||||
}
|
||||
|
||||
export function PromptEditor({
|
||||
prompt,
|
||||
onSave,
|
||||
onCancel,
|
||||
onDelete,
|
||||
error,
|
||||
}: PromptEditorProps) {
|
||||
const [formData, setFormData] = useState<PromptFormData>({
|
||||
prompt: "",
|
||||
variables: [],
|
||||
});
|
||||
|
||||
const [newVariable, setNewVariable] = useState("");
|
||||
const [variableValues, setVariableValues] = useState<Record<string, string>>(
|
||||
{}
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (prompt) {
|
||||
setFormData({
|
||||
prompt: prompt.prompt || "",
|
||||
variables: prompt.variables || [],
|
||||
});
|
||||
}
|
||||
}, [prompt]);
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (!formData.prompt.trim()) {
|
||||
return;
|
||||
}
|
||||
onSave(formData);
|
||||
};
|
||||
|
||||
const addVariable = () => {
|
||||
if (
|
||||
newVariable.trim() &&
|
||||
!formData.variables.includes(newVariable.trim())
|
||||
) {
|
||||
setFormData(prev => ({
|
||||
...prev,
|
||||
variables: [...prev.variables, newVariable.trim()],
|
||||
}));
|
||||
setNewVariable("");
|
||||
}
|
||||
};
|
||||
|
||||
const removeVariable = (variableToRemove: string) => {
|
||||
setFormData(prev => ({
|
||||
...prev,
|
||||
variables: prev.variables.filter(
|
||||
variable => variable !== variableToRemove
|
||||
),
|
||||
}));
|
||||
};
|
||||
|
||||
const renderPreview = () => {
|
||||
const text = formData.prompt;
|
||||
if (!text) return text;
|
||||
|
||||
// Split text by variable patterns and process each part
|
||||
const parts = text.split(/(\{\{\s*\w+\s*\}\})/g);
|
||||
|
||||
return parts.map((part, index) => {
|
||||
const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/);
|
||||
if (variableMatch) {
|
||||
const variableName = variableMatch[1];
|
||||
const isDefined = formData.variables.includes(variableName);
|
||||
const value = variableValues[variableName];
|
||||
|
||||
if (!isDefined) {
|
||||
// Variable not in variables list - likely a typo/bug (RED)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200 px-1 rounded font-medium"
|
||||
>
|
||||
{part}
|
||||
</span>
|
||||
);
|
||||
} else if (value && value.trim()) {
|
||||
// Variable defined and has value - show the value (GREEN)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200 px-1 rounded font-medium"
|
||||
>
|
||||
{value}
|
||||
</span>
|
||||
);
|
||||
} else {
|
||||
// Variable defined but empty (YELLOW)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200 px-1 rounded font-medium"
|
||||
>
|
||||
{part}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
}
|
||||
return part;
|
||||
});
|
||||
};
|
||||
|
||||
const updateVariableValue = (variable: string, value: string) => {
|
||||
setVariableValues(prev => ({
|
||||
...prev,
|
||||
[variable]: value,
|
||||
}));
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
{error && (
|
||||
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-md">
|
||||
<p className="text-destructive text-sm">{error}</p>
|
||||
</div>
|
||||
)}
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
|
||||
{/* Form Section */}
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Label htmlFor="prompt">Prompt Content *</Label>
|
||||
<Textarea
|
||||
id="prompt"
|
||||
value={formData.prompt}
|
||||
onChange={e =>
|
||||
setFormData(prev => ({ ...prev, prompt: e.target.value }))
|
||||
}
|
||||
placeholder="Enter your prompt content here. Use {{variable_name}} for dynamic variables."
|
||||
className="min-h-32 font-mono mt-2"
|
||||
required
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground mt-2">
|
||||
Use double curly braces around variable names, e.g.,{" "}
|
||||
{`{{user_name}}`} or {`{{topic}}`}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-3">
|
||||
<Label className="text-sm font-medium">Variables</Label>
|
||||
|
||||
<div className="flex gap-2 mt-2">
|
||||
<Input
|
||||
value={newVariable}
|
||||
onChange={e => setNewVariable(e.target.value)}
|
||||
placeholder="Add variable name (e.g. user_name, topic)"
|
||||
onKeyPress={e =>
|
||||
e.key === "Enter" && (e.preventDefault(), addVariable())
|
||||
}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={addVariable}
|
||||
size="sm"
|
||||
disabled={
|
||||
!newVariable.trim() ||
|
||||
formData.variables.includes(newVariable.trim())
|
||||
}
|
||||
>
|
||||
<Plus className="h-4 w-4" />
|
||||
Add
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{formData.variables.length > 0 && (
|
||||
<div className="border rounded-lg p-3 bg-muted/20">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{formData.variables.map(variable => (
|
||||
<Badge
|
||||
key={variable}
|
||||
variant="secondary"
|
||||
className="text-sm px-2 py-1"
|
||||
>
|
||||
{variable}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => removeVariable(variable)}
|
||||
className="ml-2 hover:text-destructive transition-colors"
|
||||
title={`Remove ${variable} variable`}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</button>
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Variables that can be used in the prompt template. Each variable
|
||||
should match a {`{{variable}}`} placeholder in the content above.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Preview Section */}
|
||||
<div className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="text-lg">Preview</CardTitle>
|
||||
<CardDescription>
|
||||
Live preview of compiled prompt and variable substitution.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
{formData.prompt ? (
|
||||
<>
|
||||
{/* Variable Values */}
|
||||
{formData.variables.length > 0 && (
|
||||
<div className="space-y-3">
|
||||
<Label className="text-sm font-medium">
|
||||
Variable Values
|
||||
</Label>
|
||||
<div className="space-y-2">
|
||||
{formData.variables.map(variable => (
|
||||
<div
|
||||
key={variable}
|
||||
className="grid grid-cols-2 gap-3 items-center"
|
||||
>
|
||||
<div className="text-sm font-mono text-muted-foreground">
|
||||
{variable}
|
||||
</div>
|
||||
<Input
|
||||
id={`var-${variable}`}
|
||||
value={variableValues[variable] || ""}
|
||||
onChange={e =>
|
||||
updateVariableValue(variable, e.target.value)
|
||||
}
|
||||
placeholder={`Enter value for ${variable}`}
|
||||
className="text-sm"
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<Separator />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Live Preview */}
|
||||
<div>
|
||||
<Label className="text-sm font-medium mb-2 block">
|
||||
Compiled Prompt
|
||||
</Label>
|
||||
<div className="bg-muted/50 p-4 rounded-lg border">
|
||||
<div className="text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{renderPreview()}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-4 mt-2 text-xs">
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-green-500 dark:bg-green-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Used</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-yellow-500 dark:bg-yellow-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Unused</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-red-500 dark:bg-red-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Undefined</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="text-center py-8">
|
||||
<div className="text-muted-foreground text-sm">
|
||||
Enter content to preview the compiled prompt
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground mt-2">
|
||||
Use {`{{variable_name}}`} to add dynamic variables
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex justify-between">
|
||||
<div>
|
||||
{prompt && onDelete && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
onClick={() => {
|
||||
if (
|
||||
confirm(
|
||||
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||
)
|
||||
) {
|
||||
onDelete(prompt.prompt_id);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Trash2 className="h-4 w-4 mr-2" />
|
||||
Delete Prompt
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button type="button" variant="outline" onClick={onCancel}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button type="submit">
|
||||
<Save className="h-4 w-4 mr-2" />
|
||||
{prompt ? "Update" : "Create"} Prompt
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
src/llama_stack/ui/components/prompts/prompt-list.test.tsx (new file, 259 lines)
@ -0,0 +1,259 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptList } from "./prompt-list";
|
||||
import type { Prompt } from "./types";
|
||||
|
||||
describe("PromptList", () => {
|
||||
const mockOnEdit = jest.fn();
|
||||
const mockOnDelete = jest.fn();
|
||||
|
||||
const defaultProps = {
|
||||
prompts: [],
|
||||
onEdit: mockOnEdit,
|
||||
onDelete: mockOnDelete,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Empty State", () => {
|
||||
test("renders empty message when no prompts", () => {
|
||||
render(<PromptList {...defaultProps} />);
|
||||
|
||||
expect(screen.getByText("No prompts yet")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows filtered empty message when search has no results", () => {
|
||||
const prompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello world",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
render(<PromptList {...defaultProps} prompts={prompts} />);
|
||||
|
||||
// Search for something that doesn't exist
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "nonexistent" } });
|
||||
|
||||
expect(
|
||||
screen.getByText("No prompts match your filters")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Prompts Display", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how are you?",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_456",
|
||||
prompt: "Summarize this {{text}} in {{length}} words",
|
||||
version: 2,
|
||||
variables: ["text", "length"],
|
||||
is_default: false,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_789",
|
||||
prompt: "Simple prompt with no variables",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("renders prompts table with correct headers", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
expect(screen.getByText("ID")).toBeInTheDocument();
|
||||
expect(screen.getByText("Content")).toBeInTheDocument();
|
||||
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||
expect(screen.getByText("Version")).toBeInTheDocument();
|
||||
expect(screen.getByText("Actions")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders prompt data correctly", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Check prompt IDs
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
expect(screen.getByText("prompt_456")).toBeInTheDocument();
|
||||
expect(screen.getByText("prompt_789")).toBeInTheDocument();
|
||||
|
||||
// Check content
|
||||
expect(
|
||||
screen.getByText("Hello {{name}}, how are you?")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Summarize this {{text}} in {{length}} words")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Simple prompt with no variables")
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Check versions
|
||||
expect(screen.getAllByText("1")).toHaveLength(2); // Two prompts with version 1
|
||||
expect(screen.getByText("2")).toBeInTheDocument();
|
||||
|
||||
// Check default badge
|
||||
expect(screen.getByText("Default")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders variables correctly", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Check variables display
|
||||
expect(screen.getByText("name")).toBeInTheDocument();
|
||||
expect(screen.getByText("text")).toBeInTheDocument();
|
||||
expect(screen.getByText("length")).toBeInTheDocument();
|
||||
expect(screen.getByText("None")).toBeInTheDocument(); // For prompt with no variables
|
||||
});
|
||||
|
||||
test("prompt ID links are clickable and call onEdit", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Click on the first prompt ID link
|
||||
const promptLink = screen.getByRole("button", { name: "prompt_123" });
|
||||
fireEvent.click(promptLink);
|
||||
|
||||
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||
});
|
||||
|
||||
test("edit buttons call onEdit", () => {
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
// Find the action buttons in the table - they should be in the last column
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const editButton = firstActionCell?.querySelector("button");
|
||||
|
||||
expect(editButton).toBeInTheDocument();
|
||||
fireEvent.click(editButton!);
|
||||
|
||||
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||
});
|
||||
|
||||
test("delete buttons call onDelete with confirmation", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
// Find the delete button (second button in the first action cell)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
expect(window.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||
);
|
||||
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
|
||||
test("delete does not execute when confirmation is cancelled", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => false);
|
||||
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
|
||||
describe("Search Functionality", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "user_greeting",
|
||||
prompt: "Hello {{name}}, welcome!",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "system_summary",
|
||||
prompt: "Summarize the following text",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("filters prompts by prompt ID", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("filters prompts by content", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "welcome" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("search is case insensitive", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "HELLO" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("clearing search shows all prompts", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
|
||||
// Filter first
|
||||
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
|
||||
// Clear search
|
||||
fireEvent.change(searchInput, { target: { value: "" } });
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.getByText("system_summary")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
src/llama_stack/ui/components/prompts/prompt-list.tsx (new file, 164 lines)
@ -0,0 +1,164 @@
|
|||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Edit, Search, Trash2 } from "lucide-react";
|
||||
import { Prompt, PromptFilters } from "./types";
|
||||
|
||||
interface PromptListProps {
|
||||
prompts: Prompt[];
|
||||
onEdit: (prompt: Prompt) => void;
|
||||
onDelete: (promptId: string) => void;
|
||||
}
|
||||
|
||||
export function PromptList({ prompts, onEdit, onDelete }: PromptListProps) {
|
||||
const [filters, setFilters] = useState<PromptFilters>({});
|
||||
|
||||
const filteredPrompts = prompts.filter(prompt => {
|
||||
if (
|
||||
filters.searchTerm &&
|
||||
!(
|
||||
prompt.prompt
|
||||
?.toLowerCase()
|
||||
.includes(filters.searchTerm.toLowerCase()) ||
|
||||
prompt.prompt_id
|
||||
.toLowerCase()
|
||||
.includes(filters.searchTerm.toLowerCase())
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{/* Filters */}
|
||||
<div className="flex flex-col sm:flex-row gap-4">
|
||||
<div className="relative flex-1">
|
||||
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
|
||||
<Input
|
||||
placeholder="Search prompts..."
|
||||
value={filters.searchTerm || ""}
|
||||
onChange={e =>
|
||||
setFilters(prev => ({ ...prev, searchTerm: e.target.value }))
|
||||
}
|
||||
className="pl-10"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Prompts Table */}
|
||||
<div className="overflow-auto">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>ID</TableHead>
|
||||
<TableHead>Content</TableHead>
|
||||
<TableHead>Variables</TableHead>
|
||||
<TableHead>Version</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{filteredPrompts.map(prompt => (
|
||||
<TableRow key={prompt.prompt_id}>
|
||||
<TableCell className="max-w-48">
|
||||
<Button
|
||||
variant="link"
|
||||
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300 max-w-full justify-start"
|
||||
onClick={() => onEdit(prompt)}
|
||||
title={prompt.prompt_id}
|
||||
>
|
||||
<div className="truncate">{prompt.prompt_id}</div>
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
<div
|
||||
className="font-mono text-xs text-muted-foreground truncate"
|
||||
title={prompt.prompt || "No content"}
|
||||
>
|
||||
{prompt.prompt || "No content"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{prompt.variables.length > 0 ? (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{prompt.variables.map(variable => (
|
||||
<Badge
|
||||
key={variable}
|
||||
variant="outline"
|
||||
className="text-xs"
|
||||
>
|
||||
{variable}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-muted-foreground text-sm">None</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm">
|
||||
{prompt.version}
|
||||
{prompt.is_default && (
|
||||
<Badge variant="secondary" className="text-xs ml-2">
|
||||
Default
|
||||
</Badge>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => onEdit(prompt)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
<Edit className="h-3 w-3" />
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
if (
|
||||
confirm(
|
||||
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||
)
|
||||
) {
|
||||
onDelete(prompt.prompt_id);
|
||||
}
|
||||
}}
|
||||
className="h-8 w-8 p-0 text-destructive hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{filteredPrompts.length === 0 && (
|
||||
<div className="text-center py-12">
|
||||
<div className="text-muted-foreground">
|
||||
{prompts.length === 0
|
||||
? "No prompts yet"
|
||||
: "No prompts match your filters"}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
src/llama_stack/ui/components/prompts/prompt-management.test.tsx (new file, 304 lines)
@ -0,0 +1,304 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptManagement } from "./prompt-management";
|
||||
import type { Prompt } from "./types";
|
||||
|
||||
// Mock the auth client
|
||||
const mockPromptsClient = {
|
||||
list: jest.fn(),
|
||||
create: jest.fn(),
|
||||
update: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => ({
|
||||
prompts: mockPromptsClient,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("PromptManagement", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Loading State", () => {
|
||||
test("renders loading state initially", () => {
|
||||
mockPromptsClient.list.mockReturnValue(new Promise(() => {})); // Never resolves
|
||||
render(<PromptManagement />);
|
||||
|
||||
expect(screen.getByText("Loading prompts...")).toBeInTheDocument();
|
||||
expect(screen.getByText("Prompts")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Empty State", () => {
|
||||
test("renders empty state when no prompts", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue([]);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("No prompts found.")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.getByText("Create Your First Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("opens modal when clicking 'Create Your First Prompt'", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue([]);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Create Your First Prompt")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Create Your First Prompt"));
|
||||
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error State", () => {
|
||||
test("renders error state when API fails", async () => {
|
||||
const error = new Error("API not found");
|
||||
mockPromptsClient.list.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/Error:/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders specific error for 404", async () => {
|
||||
const error = new Error("404 Not found");
|
||||
mockPromptsClient.list.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Prompts API endpoint not found/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Prompts List", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how are you?",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_456",
|
||||
prompt: "Summarize this {{text}}",
|
||||
version: 2,
|
||||
variables: ["text"],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("renders prompts list correctly", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.getByText("prompt_456")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Hello {{name}}, how are you?")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Summarize this {{text}}")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("opens modal when clicking 'New Prompt' button", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Modal Operations", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
];
|
||||
|
||||
test("closes modal when clicking cancel", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
|
||||
// Close modal
|
||||
fireEvent.click(screen.getByText("Cancel"));
|
||||
expect(screen.queryByText("Create New Prompt")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("creates new prompt successfully", async () => {
|
||||
const newPrompt: Prompt = {
|
||||
prompt_id: "prompt_new",
|
||||
prompt: "New prompt content",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.create.mockResolvedValue(newPrompt);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
// Fill form
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "New prompt content" },
|
||||
});
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.create).toHaveBeenCalledWith({
|
||||
prompt: "New prompt content",
|
||||
variables: [],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test("handles create error gracefully", async () => {
|
||||
const error = {
|
||||
detail: {
|
||||
errors: [{ msg: "Prompt contains undeclared variables: ['test']" }],
|
||||
},
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.create.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
// Fill form
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, { target: { value: "Hello {{test}}" } });
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Prompt contains undeclared variables: ['test']")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("updates existing prompt successfully", async () => {
|
||||
const updatedPrompt: Prompt = {
|
||||
...mockPrompts[0],
|
||||
prompt: "Updated content",
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.update.mockResolvedValue(updatedPrompt);
|
||||
const { container } = render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click edit button (first button in the action cell of the first row)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const editButton = firstActionCell?.querySelector("button");
|
||||
|
||||
expect(editButton).toBeInTheDocument();
|
||||
fireEvent.click(editButton!);
|
||||
|
||||
expect(screen.getByText("Edit Prompt")).toBeInTheDocument();
|
||||
|
||||
// Update content
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, { target: { value: "Updated content" } });
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Update Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.update).toHaveBeenCalledWith("prompt_123", {
|
||||
prompt: "Updated content",
|
||||
variables: ["name"],
|
||||
version: 1,
|
||||
set_as_default: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test("deletes prompt successfully", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.delete.mockResolvedValue(undefined);
|
||||
|
||||
// Mock window.confirm
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
const { container } = render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click delete button (second button in the action cell of the first row)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.delete).toHaveBeenCalledWith("prompt_123");
|
||||
});
|
||||
|
||||
// Restore window.confirm
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
});
|
||||
src/llama_stack/ui/components/prompts/prompt-management.tsx (new file, 233 lines)
@ -0,0 +1,233 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Plus } from "lucide-react";
|
||||
import { PromptList } from "./prompt-list";
|
||||
import { PromptEditor } from "./prompt-editor";
|
||||
import { Prompt, PromptFormData } from "./types";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
|
||||
export function PromptManagement() {
|
||||
const [prompts, setPrompts] = useState<Prompt[]>([]);
|
||||
const [showPromptModal, setShowPromptModal] = useState(false);
|
||||
const [editingPrompt, setEditingPrompt] = useState<Prompt | undefined>();
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null); // For main page errors (loading, etc.)
|
||||
const [modalError, setModalError] = useState<string | null>(null); // For form submission errors
|
||||
const client = useAuthClient();
|
||||
|
||||
// Load prompts from API on component mount
|
||||
useEffect(() => {
|
||||
const fetchPrompts = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
const response = await client.prompts.list();
|
||||
setPrompts(response || []);
|
||||
} catch (err: unknown) {
|
||||
console.error("Failed to load prompts:", err);
|
||||
|
||||
// Handle different types of errors
|
||||
const error = err as Error & { status?: number };
|
||||
if (error?.message?.includes("404") || error?.status === 404) {
|
||||
setError(
|
||||
"Prompts API endpoint not found. Please ensure your Llama Stack server supports the prompts API."
|
||||
);
|
||||
} else if (
|
||||
error?.message?.includes("not implemented") ||
|
||||
error?.message?.includes("not supported")
|
||||
) {
|
||||
setError(
|
||||
"Prompts API is not yet implemented on this Llama Stack server."
|
||||
);
|
||||
} else {
|
||||
setError(
|
||||
`Failed to load prompts: ${error?.message || "Unknown error"}`
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchPrompts();
|
||||
}, [client]);
|
||||
|
||||
const handleSavePrompt = async (formData: PromptFormData) => {
|
||||
try {
|
||||
setModalError(null);
|
||||
|
||||
if (editingPrompt) {
|
||||
// Update existing prompt
|
||||
const response = await client.prompts.update(editingPrompt.prompt_id, {
|
||||
prompt: formData.prompt,
|
||||
variables: formData.variables,
|
||||
version: editingPrompt.version,
|
||||
set_as_default: true,
|
||||
});
|
||||
|
||||
// Update local state
|
||||
setPrompts(prev =>
|
||||
prev.map(p =>
|
||||
p.prompt_id === editingPrompt.prompt_id ? response : p
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// Create new prompt
|
||||
const response = await client.prompts.create({
|
||||
prompt: formData.prompt,
|
||||
variables: formData.variables,
|
||||
});
|
||||
|
||||
// Add to local state
|
||||
setPrompts(prev => [response, ...prev]);
|
||||
}
|
||||
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
} catch (err) {
|
||||
console.error("Failed to save prompt:", err);
|
||||
|
||||
// Extract specific error message from API response
|
||||
const error = err as Error & {
|
||||
message?: string;
|
||||
detail?: { errors?: Array<{ msg?: string }> };
|
||||
};
|
||||
|
||||
// Try to parse JSON from error message if it's a string
|
||||
let parsedError = error;
|
||||
if (typeof error?.message === "string" && error.message.includes("{")) {
|
||||
try {
|
||||
const jsonMatch = error.message.match(/\d+\s+(.+)/);
|
||||
if (jsonMatch) {
|
||||
parsedError = JSON.parse(jsonMatch[1]);
|
||||
}
|
||||
} catch {
|
||||
// If parsing fails, use original error
|
||||
}
|
||||
}
|
||||
|
||||
// Try to get the specific validation error message
|
||||
const validationError = parsedError?.detail?.errors?.[0]?.msg;
|
||||
if (validationError) {
|
||||
// Clean up validation error messages (remove "Value error, " prefix if present)
|
||||
const cleanMessage = validationError.replace(/^Value error,\s*/i, "");
|
||||
setModalError(cleanMessage);
|
||||
} else {
|
||||
// For other errors, format them nicely with line breaks
|
||||
const statusMatch = error?.message?.match(/(\d+)\s+(.+)/);
|
||||
if (statusMatch) {
|
||||
const statusCode = statusMatch[1];
|
||||
const response = statusMatch[2];
|
||||
setModalError(
|
||||
`Failed to save prompt: Status Code ${statusCode}\n\nResponse: ${response}`
|
||||
);
|
||||
} else {
|
||||
const message = error?.message || error?.detail || "Unknown error";
|
||||
setModalError(`Failed to save prompt: ${message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleEditPrompt = (prompt: Prompt) => {
|
||||
setEditingPrompt(prompt);
|
||||
setShowPromptModal(true);
|
||||
setModalError(null); // Clear any previous modal errors
|
||||
};
|
||||
|
||||
const handleDeletePrompt = async (promptId: string) => {
|
||||
try {
|
||||
setError(null);
|
||||
await client.prompts.delete(promptId);
|
||||
setPrompts(prev => prev.filter(p => p.prompt_id !== promptId));
|
||||
|
||||
// If we're deleting the currently editing prompt, close the modal
|
||||
if (editingPrompt && editingPrompt.prompt_id === promptId) {
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to delete prompt:", err);
|
||||
setError("Failed to delete prompt");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreateNew = () => {
|
||||
setEditingPrompt(undefined);
|
||||
setShowPromptModal(true);
|
||||
setModalError(null); // Clear any previous modal errors
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
};
|
||||
|
||||
const renderContent = () => {
|
||||
if (loading) {
|
||||
return <div className="text-muted-foreground">Loading prompts...</div>;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return <div className="text-destructive">Error: {error}</div>;
|
||||
}
|
||||
|
||||
if (!prompts || prompts.length === 0) {
|
||||
return (
|
||||
<div className="text-center py-12">
|
||||
<p className="text-muted-foreground mb-4">No prompts found.</p>
|
||||
<Button onClick={handleCreateNew}>
|
||||
<Plus className="h-4 w-4 mr-2" />
|
||||
Create Your First Prompt
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<PromptList
|
||||
prompts={prompts}
|
||||
onEdit={handleEditPrompt}
|
||||
onDelete={handleDeletePrompt}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h1 className="text-2xl font-semibold">Prompts</h1>
|
||||
<Button onClick={handleCreateNew} disabled={loading}>
|
||||
<Plus className="h-4 w-4 mr-2" />
|
||||
New Prompt
|
||||
</Button>
|
||||
</div>
|
||||
{renderContent()}
|
||||
|
||||
{/* Create/Edit Prompt Modal */}
|
||||
{showPromptModal && (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-background border rounded-lg shadow-lg max-w-4xl w-full mx-4 max-h-[90vh] overflow-hidden">
|
||||
<div className="p-6 border-b">
|
||||
<h2 className="text-2xl font-bold">
|
||||
{editingPrompt ? "Edit Prompt" : "Create New Prompt"}
|
||||
</h2>
|
||||
</div>
|
||||
<div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
|
||||
<PromptEditor
|
||||
prompt={editingPrompt}
|
||||
onSave={handleSavePrompt}
|
||||
onCancel={handleCancel}
|
||||
onDelete={handleDeletePrompt}
|
||||
error={modalError}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
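
PromptManagement is self-contained: it fetches prompts through useAuthClient on mount and manages its own modal state, so wiring it into a page is just rendering it. A minimal sketch, assuming the UI's existing path aliases and a Next.js app-router page; the route location is illustrative:

// app/prompts/page.tsx (hypothetical location)
import { PromptManagement } from "@/components/prompts/prompt-management";

export default function PromptsPage() {
  return <PromptManagement />;
}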

16  src/llama_stack/ui/components/prompts/types.ts  Normal file
@ -0,0 +1,16 @@
export interface Prompt {
  prompt_id: string;
  prompt: string | null;
  version: number;
  variables: string[];
  is_default: boolean;
}

export interface PromptFormData {
  prompt: string;
  variables: string[];
}

export interface PromptFilters {
  searchTerm?: string;
}
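
For reference, a sketch of how these interfaces compose: PromptFormData is what the editor form collects, Prompt is what the server returns after saving. The field values and template syntax below are illustrative only:

import { Prompt, PromptFormData } from "./types";

// Data collected from the editor form before it is sent to the API.
const draft: PromptFormData = {
  prompt: "Summarize the document in one paragraph.",
  variables: ["document"],
};

// Shape of a prompt as returned by the server after saving.
const saved: Prompt = {
  prompt_id: "prompt_123",
  prompt: draft.prompt,
  version: 1,
  variables: draft.variables,
  is_default: true,
};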

36  src/llama_stack/ui/components/ui/badge.tsx  Normal file
@ -0,0 +1,36 @@
import * as React from "react";
import { cva, type VariantProps } from "class-variance-authority";

import { cn } from "@/lib/utils";

const badgeVariants = cva(
  "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
  {
    variants: {
      variant: {
        default:
          "border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
        secondary:
          "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
        destructive:
          "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
        outline: "text-foreground",
      },
    },
    defaultVariants: {
      variant: "default",
    },
  }
);

export interface BadgeProps
  extends React.HTMLAttributes<HTMLDivElement>,
    VariantProps<typeof badgeVariants> {}

function Badge({ className, variant, ...props }: BadgeProps) {
  return (
    <div className={cn(badgeVariants({ variant }), className)} {...props} />
  );
}

export { Badge, badgeVariants };
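
A short usage sketch for the variants defined above; the component name and the default/secondary mapping are illustrative, not taken from the prompt UI:

import { Badge } from "@/components/ui/badge";

export function VersionBadge(props: { version: number; isDefault: boolean }) {
  // Pick a variant based on whether this version is the default one.
  return (
    <Badge variant={props.isDefault ? "default" : "secondary"}>
      v{props.version}
    </Badge>
  );
}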

24  src/llama_stack/ui/components/ui/label.tsx  Normal file
@ -0,0 +1,24 @@
import * as React from "react";
import * as LabelPrimitive from "@radix-ui/react-label";
import { cva, type VariantProps } from "class-variance-authority";

import { cn } from "@/lib/utils";

const labelVariants = cva(
  "text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
);

const Label = React.forwardRef<
  React.ElementRef<typeof LabelPrimitive.Root>,
  React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root> &
    VariantProps<typeof labelVariants>
>(({ className, ...props }, ref) => (
  <LabelPrimitive.Root
    ref={ref}
    className={cn(labelVariants(), className)}
    {...props}
  />
));
Label.displayName = LabelPrimitive.Root.displayName;

export { Label };
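
Label forwards all props to the Radix primitive, so associating it with a control is the usual htmlFor/id pairing. A minimal sketch; the field name and styling are illustrative:

import { Label } from "@/components/ui/label";

export function ApiKeyField() {
  return (
    <div className="space-y-2">
      <Label htmlFor="api-key">API key</Label>
      <input id="api-key" type="password" className="rounded-md border px-3 py-2" />
    </div>
  );
}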

53  src/llama_stack/ui/components/ui/tabs.tsx  Normal file
@ -0,0 +1,53 @@
import * as React from "react";
import * as TabsPrimitive from "@radix-ui/react-tabs";

import { cn } from "@/lib/utils";

const Tabs = TabsPrimitive.Root;

const TabsList = React.forwardRef<
  React.ElementRef<typeof TabsPrimitive.List>,
  React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>
>(({ className, ...props }, ref) => (
  <TabsPrimitive.List
    ref={ref}
    className={cn(
      "inline-flex h-10 items-center justify-center rounded-md bg-muted p-1 text-muted-foreground",
      className
    )}
    {...props}
  />
));
TabsList.displayName = TabsPrimitive.List.displayName;

const TabsTrigger = React.forwardRef<
  React.ElementRef<typeof TabsPrimitive.Trigger>,
  React.ComponentPropsWithoutRef<typeof TabsPrimitive.Trigger>
>(({ className, ...props }, ref) => (
  <TabsPrimitive.Trigger
    ref={ref}
    className={cn(
      "inline-flex items-center justify-center whitespace-nowrap rounded-sm px-3 py-1.5 text-sm font-medium ring-offset-background transition-all focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 data-[state=active]:bg-background data-[state=active]:text-foreground data-[state=active]:shadow-sm",
      className
    )}
    {...props}
  />
));
TabsTrigger.displayName = TabsPrimitive.Trigger.displayName;

const TabsContent = React.forwardRef<
  React.ElementRef<typeof TabsPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof TabsPrimitive.Content>
>(({ className, ...props }, ref) => (
  <TabsPrimitive.Content
    ref={ref}
    className={cn(
      "mt-2 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
      className
    )}
    {...props}
  />
));
TabsContent.displayName = TabsPrimitive.Content.displayName;

export { Tabs, TabsList, TabsTrigger, TabsContent };
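
The four exports compose around a shared value: TabsList holds the triggers and each TabsContent pane renders when its value is active. A usage sketch with illustrative tab names (not the actual prompt-editor layout):

import { Tabs, TabsList, TabsTrigger, TabsContent } from "@/components/ui/tabs";

export function EditPreviewTabs() {
  return (
    <Tabs defaultValue="edit">
      <TabsList>
        <TabsTrigger value="edit">Edit</TabsTrigger>
        <TabsTrigger value="preview">Preview</TabsTrigger>
      </TabsList>
      <TabsContent value="edit">Editor goes here</TabsContent>
      <TabsContent value="preview">Rendered preview goes here</TabsContent>
    </Tabs>
  );
}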

23  src/llama_stack/ui/components/ui/textarea.tsx  Normal file
@ -0,0 +1,23 @@
import * as React from "react";

import { cn } from "@/lib/utils";

export type TextareaProps = React.TextareaHTMLAttributes<HTMLTextAreaElement>;

const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
  ({ className, ...props }, ref) => {
    return (
      <textarea
        className={cn(
          "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
          className
        )}
        ref={ref}
        {...props}
      />
    );
  }
);
Textarea.displayName = "Textarea";

export { Textarea };
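
A sketch of the kind of labeled field the prompt editor renders (the tests above locate it via the "Prompt Content *" label). The actual wiring inside prompt-editor.tsx may differ; this only shows how Label and Textarea pair up:

import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";

export function PromptContentField(props: {
  value: string;
  onChange: (value: string) => void;
}) {
  return (
    <div className="space-y-2">
      <Label htmlFor="prompt-content">Prompt Content *</Label>
      <Textarea
        id="prompt-content"
        value={props.value}
        onChange={e => props.onChange(e.target.value)}
        placeholder="Write your prompt template here"
      />
    </div>
  );
}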

62  src/llama_stack/ui/package-lock.json  generated
@ -11,14 +11,16 @@
        "@radix-ui/react-collapsible": "^1.1.12",
        "@radix-ui/react-dialog": "^1.1.15",
        "@radix-ui/react-dropdown-menu": "^2.1.16",
        "@radix-ui/react-label": "^2.1.7",
        "@radix-ui/react-select": "^2.2.6",
        "@radix-ui/react-separator": "^1.1.7",
        "@radix-ui/react-slot": "^1.2.3",
        "@radix-ui/react-tabs": "^1.1.13",
        "@radix-ui/react-tooltip": "^1.2.8",
        "class-variance-authority": "^0.7.1",
        "clsx": "^2.1.1",
        "framer-motion": "^12.23.24",
        "llama-stack-client": "^0.3.0",
        "llama-stack-client": "github:llamastack/llama-stack-client-typescript",
        "lucide-react": "^0.545.0",
        "next": "15.5.4",
        "next-auth": "^4.24.11",
@ -2597,6 +2599,29 @@
        }
      }
    },
    "node_modules/@radix-ui/react-label": {
      "version": "2.1.7",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.7.tgz",
      "integrity": "sha512-YT1GqPSL8kJn20djelMX7/cTRp/Y9w5IZHvfxQTVHrOqa2yMl7i/UfMqKRU5V7mEyKTrUVgJXhNQPVCG8PBLoQ==",
      "license": "MIT",
      "dependencies": {
        "@radix-ui/react-primitive": "2.1.3"
      },
      "peerDependencies": {
        "@types/react": "*",
        "@types/react-dom": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
        "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        },
        "@types/react-dom": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-menu": {
      "version": "2.1.16",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.16.tgz",
@ -2855,6 +2880,36 @@
        }
      }
    },
    "node_modules/@radix-ui/react-tabs": {
      "version": "1.1.13",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.13.tgz",
      "integrity": "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A==",
      "license": "MIT",
      "dependencies": {
        "@radix-ui/primitive": "1.1.3",
        "@radix-ui/react-context": "1.1.2",
        "@radix-ui/react-direction": "1.1.1",
        "@radix-ui/react-id": "1.1.1",
        "@radix-ui/react-presence": "1.1.5",
        "@radix-ui/react-primitive": "2.1.3",
        "@radix-ui/react-roving-focus": "1.1.11",
        "@radix-ui/react-use-controllable-state": "1.2.2"
      },
      "peerDependencies": {
        "@types/react": "*",
        "@types/react-dom": "*",
        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
        "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
      },
      "peerDependenciesMeta": {
        "@types/react": {
          "optional": true
        },
        "@types/react-dom": {
          "optional": true
        }
      }
    },
    "node_modules/@radix-ui/react-tooltip": {
      "version": "1.2.8",
      "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.8.tgz",
@ -9629,9 +9684,8 @@
      "license": "MIT"
    },
    "node_modules/llama-stack-client": {
      "version": "0.3.0",
      "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.3.0.tgz",
      "integrity": "sha512-76K/t1doaGmlBbDxCADaral9Vccvys9P8pqAMIhwBhMAqWudCEORrMMhUSg+pjhamWmEKj3wa++d4zeOGbfN/w==",
      "version": "0.4.0-alpha.1",
      "resolved": "git+ssh://git@github.com/llamastack/llama-stack-client-typescript.git#78de4862c4b7d77939ac210fa9f9bde77a2c5c5f",
      "license": "MIT",
      "dependencies": {
        "@types/node": "^18.11.18",

src/llama_stack/ui/package.json
@ -16,14 +16,16 @@
"@radix-ui/react-collapsible": "^1.1.12",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.16",
|
||||
"@radix-ui/react-label": "^2.1.7",
|
||||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^12.23.24",
|
||||
"llama-stack-client": "^0.3.0",
|
||||
"llama-stack-client": "github:llamastack/llama-stack-client-typescript",
|
||||
"lucide-react": "^0.545.0",
|
||||
"next": "15.5.4",
|
||||
"next-auth": "^4.24.11",
|
||||
|
|
|
|||
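
Because llama-stack-client now resolves from the GitHub repository rather than the npm registry, a clean install picks up the pinned commit recorded in the lockfile; running `npm ci` and then `npm ls llama-stack-client` from src/llama_stack/ui/ should show the git-resolved 0.4.0-alpha.1 version.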