Merge branch 'main' into toolcall-arg-recursive-type

Ben Keith 2025-11-04 11:18:07 -05:00 committed by GitHub
commit 56aa508f82
443 changed files with 98114 additions and 79944 deletions


@ -491,13 +491,6 @@ class Agents(Protocol):
APIs for creating and interacting with agentic systems."""
@webmethod(
route="/agents",
method="POST",
descriptive_name="create_agent",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents",
method="POST",
@ -515,13 +508,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
@ -552,13 +538,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
@ -586,12 +565,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
@ -612,12 +585,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
@ -640,13 +607,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
@ -666,12 +626,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
@ -692,12 +646,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
@ -715,12 +663,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}",
method="DELETE",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
@ -732,7 +674,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
@ -743,12 +684,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
@ -758,12 +693,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/sessions",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
@ -787,12 +716,6 @@ class Agents(Protocol):
#
# Both of these APIs are inherently stateful.
@webmethod(
route="/openai/v1/responses/{response_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_openai_response(
self,
@ -805,7 +728,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
async def create_openai_response(
self,
@ -842,7 +764,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_responses(
self,
@ -861,9 +782,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_response_input_items(
self,
@ -886,7 +804,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete a response.

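A pattern worth noting here, because it repeats across most of the files in this commit: the deprecated /openai/v1/* route aliases are removed, leaving only the canonical routes served under the stack's version prefix. A hedged sketch of what that means for a raw-HTTP caller, assuming a local server on the default port 8321 and an illustrative model id:

import httpx

base = "http://localhost:8321"
payload = {"model": "llama3.2:3b", "input": "hello"}  # model id is illustrative

# The deprecated alias under .../openai/v1/responses is gone; only the
# canonical route under the v1 prefix remains.
resp = httpx.post(f"{base}/v1/responses", json=payload)
print(resp.status_code, resp.json().get("id"))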

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Sequence
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, model_validator
@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
scenarios.
"""
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] = "message"
@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
"""
id: str
queries: list[str]
queries: Sequence[str]
status: str
type: Literal["file_search_call"] = "file_search_call"
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
@json_schema_type
@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
id: str
model: str
object: Literal["response"] = "response"
output: list[OpenAIResponseOutput]
output: Sequence[OpenAIResponseOutput]
parallel_tool_calls: bool = False
previous_response_id: str | None = None
prompt: OpenAIResponsePrompt | None = None
@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
# before the field was added. New responses will have this set always.
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
tools: list[OpenAIResponseTool] | None = None
tools: Sequence[OpenAIResponseTool] | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
:param object: Object type identifier, always "list"
"""
data: list[OpenAIResponseInput]
data: Sequence[OpenAIResponseInput]
object: Literal["list"] = "list"
@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
:param input: List of input items that led to this response
"""
input: list[OpenAIResponseInput]
input: Sequence[OpenAIResponseInput]
def to_response_object(self) -> OpenAIResponseObject:
"""Convert to OpenAIResponseObject by excluding input field."""
@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
:param object: Object type identifier, always "list"
"""
data: list[OpenAIResponseObjectWithInput]
data: Sequence[OpenAIResponseObjectWithInput]
has_more: bool
first_id: str
last_id: str
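The list-to-Sequence widening in this file is about variance: list is invariant under mypy, so a field typed list[OpenAIResponseInput] rejects a list of a narrower member type, whereas Sequence is covariant and accepts it. A minimal sketch of the distinction using made-up stand-in models (the error is reported by mypy with the pydantic plugin; at runtime pydantic accepts both):

from collections.abc import Sequence

from pydantic import BaseModel

class Item(BaseModel):
    text: str

class SpecialItem(Item):
    score: float = 0.0

class WithList(BaseModel):
    data: list[Item]

class WithSequence(BaseModel):
    data: Sequence[Item]

special: list[SpecialItem] = [SpecialItem(text="hi", score=1.0)]

WithList(data=special)      # mypy error: list is invariant, so list[SpecialItem] is not list[Item]
WithSequence(data=special)  # ok: Sequence[Item] is covariant and accepts list[SpecialItem]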


@ -43,7 +43,6 @@ class Batches(Protocol):
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
async def create_batch(
self,
@ -64,7 +63,6 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
@ -74,7 +72,6 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
@ -84,7 +81,6 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
async def list_batches(
self,


@ -8,7 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, webmethod
@ -54,7 +54,6 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
@ -63,7 +62,6 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_benchmark(
self,
@ -76,7 +74,6 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def register_benchmark(
self,
@ -98,7 +95,6 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.


@ -8,7 +8,7 @@ from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
from llama_stack.schema_utils import webmethod
@ -21,7 +21,6 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
async def iterrows(
self,
@ -46,9 +45,6 @@ class DatasetIO(Protocol):
"""
...
@webmethod(
route="/datasetio/append-rows/{dataset_id:path}", method="POST", deprecated=True, level=LLAMA_STACK_API_V1
)
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.


@ -10,7 +10,7 @@ from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -146,7 +146,6 @@ class ListDatasetsResponse(BaseModel):
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
async def register_dataset(
self,
@ -216,7 +215,6 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
async def get_dataset(
self,
@ -229,7 +227,6 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
@ -238,7 +235,6 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
async def unregister_dataset(
self,


@ -13,7 +13,7 @@ from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -86,7 +86,6 @@ class Eval(Protocol):
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def run_eval(
self,
@ -101,9 +100,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def evaluate_rows(
self,
@ -122,9 +118,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
@ -135,12 +128,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
@ -150,12 +137,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
)


@ -110,7 +110,6 @@ class Files(Protocol):
"""
# OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
async def openai_upload_file(
self,
@ -134,7 +133,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_files(
self,
@ -155,7 +153,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file(
self,
@ -170,7 +167,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_file(
self,
@ -183,7 +179,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file_content(
self,


@ -1189,7 +1189,6 @@ class InferenceProvider(Protocol):
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
@ -1202,7 +1201,6 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
@ -1215,7 +1213,6 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
@ -1240,7 +1237,6 @@ class Inference(InferenceProvider):
- Rerank models: these models reorder the documents based on their relevance to a query.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
async def list_chat_completions(
self,
@ -1259,9 +1255,6 @@ class Inference(InferenceProvider):
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Get chat completion.


@ -4,14 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.apis.version import (
LLAMA_STACK_API_V1,
)
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod
# Valid values for the route filter parameter.
# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
# Special filter value: "deprecated" (shows deprecated routes regardless of level)
ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]
@json_schema_type
class RouteInfo(BaseModel):
@ -64,11 +71,12 @@ class Inspect(Protocol):
"""
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self) -> ListRoutesResponse:
async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
"""List routes.
List all available API routes with their methods and implementing providers.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
:returns: Response containing information about all available routes.
"""
...
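For illustration, the new filter can be exercised directly over HTTP. This sketch assumes a local stack on the default port 8321 and that api_filter is passed as a query-string parameter:

import httpx

base = "http://localhost:8321"

# Default behaviour: only non-deprecated v1 routes.
v1_only = httpx.get(f"{base}/v1/inspect/routes").json()["data"]

# Deprecated routes across all levels, or a specific level such as "v1alpha".
deprecated = httpx.get(f"{base}/v1/inspect/routes", params={"api_filter": "deprecated"}).json()["data"]
alpha = httpx.get(f"{base}/v1/inspect/routes", params={"api_filter": "v1alpha"}).json()["data"]

for r in deprecated:
    print(r["method"], r["route"], r["provider_types"])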


@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
:object: The object type, which will be "model"
:created: The Unix timestamp in seconds when the model was created
:owned_by: The owner of the model
:custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
"""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
custom_metadata: dict[str, Any] | None = None
class OpenAIListModelsResponse(BaseModel):
@ -105,7 +107,6 @@ class OpenAIListModelsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
async def list_models(self) -> ListModelsResponse:
"""List all models.
@ -113,7 +114,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.


@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -284,7 +284,6 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def supervised_fine_tune(
self,
@ -312,7 +311,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def preference_optimize(
self,
@ -335,7 +333,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
@ -344,7 +341,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
@ -354,7 +350,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
@ -363,7 +358,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.


@ -121,7 +121,6 @@ class Safety(Protocol):
"""
...
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
"""Create moderation.


@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import *


@ -1,77 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol
from pydantic import BaseModel
from llama_stack.apis.inference import Message
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
class FilteringFunction(Enum):
"""The type of filtering function.
:cvar none: No filtering applied, accept all generated synthetic data
:cvar random: Random sampling of generated data points
:cvar top_k: Keep only the top-k highest scoring synthetic data samples
:cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
:cvar top_k_top_p: Combined top-k and top-p filtering strategy
:cvar sigmoid: Apply sigmoid function for probability-based filtering
"""
none = "none"
random = "random"
top_k = "top_k"
top_p = "top_p"
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"
@json_schema_type
class SyntheticDataGenerationRequest(BaseModel):
"""Request to generate synthetic data. A small batch of prompts and a filtering function
:param dialogs: List of conversation messages to use as input for synthetic data generation
:param filtering_function: Type of filtering to apply to generated synthetic data samples
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
"""
dialogs: list[Message]
filtering_function: FilteringFunction = FilteringFunction.none
model: str | None = None
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
:param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
:param statistics: (Optional) Statistical information about the generation process and filtering results
"""
synthetic_data: list[dict[str, Any]]
statistics: dict[str, Any] | None = None
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
def synthetic_data_generate(
self,
dialogs: list[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: str | None = None,
) -> SyntheticDataGenerationResponse:
"""Generate synthetic data based on input dialogs and apply filtering.
:param dialogs: List of conversation messages to use as input for synthetic data generation
:param filtering_function: Type of filtering to apply to generated synthetic data samples
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
:returns: Response containing filtered synthetic data samples and optional statistics
"""
...


@ -8,7 +8,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
@ -18,7 +17,6 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
@ -61,38 +59,19 @@ class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param chunk_id: Unique identifier for the chunk. Must be provided explicitly.
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
The `chunk_metadata` is required for backend functionality.
"""
content: InterleavedContent
chunk_id: str
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
chunk_metadata: ChunkMetadata | None = None
model_config = {"populate_by_name": True}
def model_post_init(self, __context):
# Extract chunk_id from metadata if present
if self.metadata and "chunk_id" in self.metadata:
self.stored_chunk_id = self.metadata.pop("chunk_id")
@property
def chunk_id(self) -> str:
"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
if self.stored_chunk_id:
return self.stored_chunk_id
if "document_id" in self.metadata:
return generate_chunk_id(self.metadata["document_id"], str(self.content))
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
@property
def document_id(self) -> str | None:
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
@ -566,7 +545,6 @@ class VectorIO(Protocol):
...
# OpenAI Vector Stores API endpoints
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store(
self,
@ -579,7 +557,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_vector_stores(
self,
@ -598,9 +575,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_vector_store(
self,
@ -613,9 +587,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="POST",
@ -638,9 +609,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="DELETE",
@ -657,12 +625,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/search",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/search",
method="POST",
@ -695,12 +657,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="POST",
@ -723,12 +679,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="GET",
@ -755,12 +705,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
@ -779,12 +723,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
@ -803,12 +741,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
@ -829,12 +761,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
@ -858,12 +784,6 @@ class VectorIO(Protocol):
method="POST",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
@ -882,12 +802,6 @@ class VectorIO(Protocol):
method="GET",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
@ -901,12 +815,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
@ -935,12 +843,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",


@ -8,16 +8,30 @@ import argparse
import os
import ssl
import subprocess
import sys
from pathlib import Path
import uvicorn
import yaml
from termcolor import cprint
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import LoggingConfig, get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -68,6 +82,12 @@ class StackRun(Subcommand):
action="store_true",
help="Start the UI server",
)
self.parser.add_argument(
"--providers",
type=str,
default=None,
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
@ -93,6 +113,55 @@ class StackRun(Subcommand):
config_file = resolve_config_or_distro(args.config, Mode.RUN)
except ValueError as e:
self.parser.error(str(e))
elif args.providers:
provider_list: dict[str, list[Provider]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider_type = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
f"{api} is not a valid API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
if provider_type in providers_for_api:
config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
else:
config = {}
provider = Provider(
provider_type=provider_type,
config=config,
provider_id=provider_type.split("::")[1],
)
provider_list.setdefault(api, []).append(provider)
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
run_config = self._generate_run_config_from_providers(providers=provider_list)
config_dict = run_config.model_dump(mode="json")
# Write config to disk in providers-run directory
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
config_file = distro_dir / "run.yaml"
logger.info(f"Writing generated config to: {config_file}")
with open(config_file, "w") as f:
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
else:
config_file = None
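To make the accepted --providers format concrete, here is a small standalone sketch of the same parsing shape; it skips the registry validation and Provider construction done above, and the provider names are purely illustrative:

def parse_providers(spec: str) -> dict[str, list[str]]:
    # "api1=provider1,api1=provider2,api2=provider3" -> {api: [provider_type, ...]}
    out: dict[str, list[str]] = {}
    for pair in spec.split(","):
        if "=" not in pair:
            raise ValueError(f"could not parse entry: {pair!r}")
        api, provider_type = pair.split("=", 1)
        out.setdefault(api, []).append(provider_type)
    return out

print(parse_providers("inference=remote::ollama,safety=inline::llama-guard"))
# {'inference': ['remote::ollama'], 'safety': ['inline::llama-guard']}
# provider_id is then derived from the segment after "::", e.g. "remote::ollama" -> "ollama"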
@ -106,7 +175,8 @@ class StackRun(Subcommand):
try:
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
# Create external_providers_dir if it's specified and doesn't exist
if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
@ -127,7 +197,7 @@ class StackRun(Subcommand):
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
port = args.port or config.server.port
host = config.server.host or ["::", "0.0.0.0"]
host = config.server.host or "0.0.0.0"
# Set the config file in environment so create_app can find it
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
@ -139,6 +209,7 @@ class StackRun(Subcommand):
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
"workers": config.server.workers,
}
keyfile = config.server.tls_keyfile
@ -212,3 +283,44 @@ class StackRun(Subcommand):
)
except Exception as e:
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
apis = list(providers.keys())
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
# need somewhere to put the storage.
os.makedirs(distro_dir, exist_ok=True)
storage = StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
),
"sql_default": SqliteSqlStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(
backend="kv_default",
namespace="registry",
),
inference=InferenceStoreReference(
backend="sql_default",
table_name="inference_store",
),
conversations=SqlStoreReference(
backend="sql_default",
table_name="openai_conversations",
),
prompts=KVStoreReference(
backend="kv_default",
namespace="prompts",
),
),
)
return StackRunConfig(
image_name="providers-run",
apis=apis,
providers=providers,
storage=storage,
)


@ -17,7 +17,6 @@ from llama_stack.core.distribution import (
get_provider_registry,
)
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
@ -194,19 +193,11 @@ def upgrade_from_routing_table(
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))


@ -473,6 +473,10 @@ class ServerConfig(BaseModel):
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
workers: int = Field(
default=1,
description="Number of workers to use for the server",
)
class StackRunConfig(BaseModel):


@ -15,6 +15,7 @@ from llama_stack.apis.inspect import (
RouteInfo,
VersionInfo,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None:
pass
async def list_routes(self) -> ListRoutesResponse:
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
run_config: StackRunConfig = self.config.run_config
# Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool:
if api_filter is None:
# Default: only non-deprecated v1 APIs
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
elif api_filter == "deprecated":
# Special filter: show deprecated routes regardless of their actual level
return bool(webmethod.deprecated)
else:
# Filter by API level (non-deprecated routes only)
return not webmethod.deprecated and webmethod.level == api_filter
ret = []
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e, _ in endpoints
if e.methods is not None
for e, webmethod in endpoints
if e.methods is not None and should_include_route(webmethod)
]
)
else:
@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e, _ in endpoints
if e.methods is not None
for e, webmethod in endpoints
if e.methods is not None and should_include_route(webmethod)
]
)


@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
@ -42,19 +44,104 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
await self.update_registered_models(provider_id, models)
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
"""
Fetch models from providers that have credentials in the current request's provider_data.
This allows users to see models available to them from providers that require
per-request API keys (via X-LlamaStack-Provider-Data header).
Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
cache them in the registry since they are user-specific.
"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return []
dynamic_models = []
for provider_id, provider in self.impls_by_provider_id.items():
# Check if this provider supports provider_data
if not isinstance(provider, NeedsRequestProviderData):
continue
# Check if provider has a validator (some providers like ollama don't need per-request credentials)
spec = getattr(provider, "__provider_spec__", None)
if not spec or not getattr(spec, "provider_data_validator", None):
continue
# Validate provider_data silently - we're speculatively checking all providers
# so validation failures are expected when user didn't provide keys for this provider
try:
validator = instantiate_class_type(spec.provider_data_validator)
validator(**provider_data)
except Exception:
# User didn't provide credentials for this provider - skip silently
continue
# Validation succeeded! User has credentials for this provider
# Now try to list models
try:
models = await provider.list_models()
if not models:
continue
# Ensure models have fully qualified identifiers with provider_id prefix
for model in models:
# Only add prefix if model identifier doesn't already have it
if not model.identifier.startswith(f"{provider_id}/"):
model.identifier = f"{provider_id}/{model.provider_resource_id}"
dynamic_models.append(model)
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
except Exception as e:
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
continue
return dynamic_models
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
# Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
return ListModelsResponse(data=registry_models + unique_dynamic_models)
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
# Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
all_models = registry_models + unique_dynamic_models
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
custom_metadata={
"model_type": model.model_type,
"provider_id": model.provider_id,
"provider_resource_id": model.provider_resource_id,
**model.metadata,
},
)
for model in models
for model in all_models
]
return OpenAIListModelsResponse(data=openai_models)
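The per-request credential flow is easiest to see from the client side. A hedged sketch using raw httpx (the base URL and the provider_data key name are assumptions; the header name comes from the docstring above):

import json

import httpx

# Registry models are always listed; models behind per-request credentials appear
# only when the X-LlamaStack-Provider-Data payload validates for that provider.
headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"together_api_key": "sk-..."}),  # illustrative key
}
resp = httpx.get("http://localhost:8321/v1/models", headers=headers)
for m in resp.json()["data"]:
    meta = m.get("custom_metadata") or {}
    print(m["id"], meta.get("provider_id"), meta.get("model_type"))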


@ -14,6 +14,7 @@ from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
@ -30,7 +31,6 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
@ -63,8 +63,8 @@ class LlamaStack(
Providers,
Inference,
Agents,
Batches,
Safety,
SyntheticDataGeneration,
Datasets,
PostTraining,
VectorIO,


@ -12,7 +12,7 @@ from llama_stack.core.ui.modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
models_info = {m.id: m.model_dump() for m in llama_stack_api.client.models.list()}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])
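Because /v1/models now serves the OpenAI-style listing, client-side model objects expose id and the new custom_metadata field rather than identifier and to_dict(). A sketch with the Python client, assuming a client version that already surfaces custom_metadata and a locally running stack:

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
for m in client.models.list():
    meta = m.custom_metadata or {}
    print(m.id, meta.get("model_type"), meta.get("provider_id"))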


@ -12,7 +12,11 @@ from llama_stack.core.ui.modules.api import llama_stack_api
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
available_models = [
model.id
for model in available_models
if model.custom_metadata and model.custom_metadata.get("model_type") == "llm"
]
selected_model = st.selectbox(
"Choose a model",
available_models,


@ -11,6 +11,7 @@ import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any, cast
import httpx
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
)
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
messages: list[Message] = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set()
for step in turn.steps:
if step.step_type == StepType.tool_execution.value:
if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.append(msg)
for step in turn.steps:
if step.step_type == StepType.inference.value:
if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value:
elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
if step.violation:
elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
if step.violation and step.violation.user_message:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
messages: list[Message] = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)
if is_resume:
assert isinstance(request, AgentTurnResumeRequest)
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now(UTC).isoformat()
now_dt = datetime.now(UTC)
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
completed_at=now_dt,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
input_messages = last_turn.input_messages
# Cast needed due to list invariance - last_turn.input_messages is the right type
input_messages = last_turn.input_messages # type: ignore[assignment]
turn_id = request.turn_id
actual_turn_id = request.turn_id
start_time = last_turn.started_at
else:
assert isinstance(request, AgentTurnCreateRequest)
messages.extend(request.messages)
start_time = datetime.now(UTC).isoformat()
input_messages = request.messages
start_time = datetime.now(UTC)
# Cast needed due to list invariance - request.messages is the right type
input_messages = request.messages # type: ignore[assignment]
# Use the generated turn_id from beginning of function
actual_turn_id = turn_id if turn_id else str(uuid.uuid4())
output_message = None
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
req_sampling = (
self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
)
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
turn_id=actual_turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
sampling_params=req_sampling,
stream=request.stream,
documents=request.documents if not is_resume else None,
documents=req_documents,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
event.payload, "step_details"
):
step_details = event.payload.step_details
steps.append(step_details)
yield chunk
assert output_message is not None
turn = Turn(
turn_id=turn_id,
turn_id=actual_turn_id,
session_id=request.session_id,
input_messages=input_messages,
input_messages=input_messages, # type: ignore[arg-type]
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
@ -345,9 +361,9 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
if len(self.input_shields) > 0:
if self.input_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
):
if isinstance(res, bool):
return
@ -374,9 +390,9 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
if len(self.output_shields) > 0:
if self.output_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
@ -388,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:
@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now(UTC).isoformat()
shield_call_start_time = datetime.now(UTC)
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
else:
self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
output_attachments = []
output_attachments: list[Attachment] = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat()
inference_start_time = datetime.now(UTC)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
)
)
@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
def _add_type(openai_msg: Any) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
messages=openai_messages,
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
response_format=self.agent_config.response_format,
response_format=self.agent_config.response_format, # type: ignore[arg-type]
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
openai_stream, # type: ignore[arg-type]
enable_incremental_tool_calls=True,
)
async for chunk in response_stream:
@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
"tool_calls": [
json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
],
}
)
span.set_attribute("output", output_attr)
@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
if tool_calls:
content = ""
# Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
message = CompletionMessage(
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
tool_calls=valid_tool_calls if valid_tool_calls else None,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some
@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
)
if n_iter >= self.agent_config.max_infer_iters:
max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
if n_iter >= max_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
yield message
break
if len(message.tool_calls) == 0:
if not message.tool_calls or len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += output_attachments
# List invariance - attachments are compatible at runtime
message.content += output_attachments # type: ignore[arg-type]
else:
message.content = [message.content] + output_attachments
# List invariance - attachments are compatible at runtime
message.content = [message.content] + output_attachments # type: ignore[assignment]
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
@ -725,11 +750,12 @@ class ChatAgent(ShieldRunnerMixin):
non_client_tool_calls = []
# Separate client and non-client tool calls
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
else:
non_client_tool_calls.append(tool_call)
if message.tool_calls:
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
else:
non_client_tool_calls.append(tool_call)
# Process non-client tool calls first
for tool_call in non_client_tool_calls:
@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
)
)
@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.telemetry_enabled
else {},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_execution_start_time = datetime.now(UTC)
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
)
# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
step_details=tool_execution_step,
)
@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(UTC).isoformat(),
started_at=datetime.now(UTC),
),
)
@ -868,19 +894,20 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_name_to_def: dict[str, ToolDefinition] = {}
tool_name_to_args = {}
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
if self.agent_config.client_tools:
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
# Use input_schema from ToolDef directly
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
# Use input_schema from ToolDef directly
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
else:
identifier = None
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[identifier] = ToolDefinition(
tool_name=identifier,
# Convert BuiltinTool to string for dictionary key
identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
if tool_name_to_def.get(identifier_str, None):
raise ValueError(f"Tool {identifier_str} already exists")
tool_name_to_def[identifier_str] = ToolDefinition(
tool_name=identifier_str,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
@ -966,14 +995,17 @@ class ChatAgent(ShieldRunnerMixin):
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
result = await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**args,
**self.tool_name_to_args.get(tool_name_str, {}),
},
result = cast(
ToolInvocationResult,
await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
# get the arguments generated by the model and augment with toolgroup arg overrides for the agent
**args,
**self.tool_name_to_args.get(tool_name_str, {}),
},
),
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
@ -983,7 +1015,7 @@ async def load_data_from_url(url: str) -> str:
if url.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(url)
resp = r.text
resp: str = r.text
return resp
raise ValueError(f"Unexpected URL: {type(url)}")
@ -1017,7 +1049,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)

View file

@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
created_at=agent_info.created_at.isoformat(),
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_session(
self,
agent_id: str,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
started_at=session_info.started_at,
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
created_at=datetime.fromisoformat(chat_agent.created_at),
)
return agent
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
response_id: str,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.get_openai_response(response_id)
async def create_openai_response(
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
input,
model,
prompt,
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters,
guardrails,
)
return result # type: ignore[no-any-return]
async def list_openai_responses(
self,
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
async def list_openai_response_input_items(
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order
)
async def delete_openai_response(self, response_id: str) -> None:
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.delete_openai_response(response_id)

View file

@ -6,12 +6,14 @@
import json
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.access_control.conditions import User as ProtocolUser
from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
created_at: datetime
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
@ -53,8 +64,15 @@ class AgentPersistence:
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError("create", session_info, user)
# Only perform access control if we have an authenticated user
if user is not None and session_info.identifier is not None:
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=user,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -62,7 +80,7 @@ class AgentPersistence:
)
return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
@ -83,7 +101,22 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
# Get current user - if None, skip access control (e.g., in tests)
user = get_authenticated_user()
if user is None:
return True
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, user)
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
input: str | list[OpenAIResponseInput],
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
new_input_items = previous_response.input.copy()
# Convert Sequence to list for mutation
new_input_items = list(previous_response.input)
new_input_items.extend(previous_response.output)
if isinstance(input, str):
@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
conversation: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
"""Process input with optional previous response context.
Returns:
@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
input_items_data: list[OpenAIResponseInput] = []
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
@ -251,7 +254,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrailSpec] | None = None,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@ -289,16 +292,19 @@ class OpenAIResponsesImpl:
failed_response = None
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
match stream_chunk.type:
case "response.completed" | "response.incomplete":
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
f"Earlier response from {final_event_type}"
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
case "response.failed":
failed_response = stream_chunk.response
case _:
pass # Other event types don't have .response
if failed_response is not None:
error_message = (
@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
assert text is not None, "text must not be None"
assert max_infer_iters is not None, "max_infer_iters must not be None"
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id, conversation
@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
final_response = None
failed_response = None
output_items = []
# Type as ConversationItem to avoid list invariance issues
output_items: list[ConversationItem] = []
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type in {"response.completed", "response.incomplete"}:
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
failed_response = stream_chunk.response
if stream_chunk.type == "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
match stream_chunk.type:
case "response.completed" | "response.incomplete":
final_response = stream_chunk.response
case "response.failed":
failed_response = stream_chunk.response
case "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
case _:
pass # Other event types
# Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
# Type as ConversationItem union to avoid list invariance issues
conversation_items: list[ConversationItem] = []
if isinstance(input, str):
conversation_items.append(

View file

@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
instructions: str,
instructions: str | None,
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
self.prompt = prompt
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
ctx.tool_context.previous_tools if ctx.tool_context else {}
)
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
# Pydantic models are dict-compatible but mypy treats them as distinct types
tools=self.ctx.chat_tools, # type: ignore[arg-type]
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
has_tool_calls = (
isinstance(choice.message, OpenAIAssistantMessageParam)
and choice.message.tool_calls
and self.ctx.response_tools
)
if not has_tool_calls:
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
if not is_new_tool_call and response_tool_call is not None:
# Both should have functions since we're inside the tool_call.function check above
assert response_tool_call.function is not None
assert tool_call.function is not None
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
# Ensure that arguments, if sent back to the inference provider, are not None
tool_call.function.arguments = tool_call.function.arguments or "{}"
if tool_call.function:
tool_call.function.arguments = tool_call.function.arguments or "{}"
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = tool_call.function.arguments
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
final_arguments: str = (tool_call.function.arguments or "{}") if tool_call.function else "{}"
func = chat_response_tool_calls[tool_call_index].function
tool_call_name = func.name if func else ""
# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:
self.sequence_number += 1
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
item = OpenAIResponseOutputMessageMCPCall(
item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
arguments="",
name=tool_call.function.name,
id=matching_item_id,
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
status="in_progress",
)
elif tool_call.function.name == "web_search":
item = OpenAIResponseOutputMessageWebSearchToolCall(
@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
description=tool.description,
input_schema=tool.input_schema,
)
return convert_tooldef_to_openai_tool(tool_def)
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:
@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
always_allowed = mcp_tool.allowed_tools.always
never_allowed = mcp_tool.allowed_tools.never
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
always_allowed = mcp_tool.allowed_tools.tool_names
# Call list_mcp_tools
tool_defs = None
@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:
@ -1120,13 +1133,17 @@ class StreamingResponseOrchestrator:
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
yield stream_event
# tool_context can be None when no tools are provided in the response request
if self.ctx.tool_context:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(
self.ctx.tool_context.tools_to_process, output_messages
):
yield stream_event
def _approval_required(self, tool_name: str) -> bool:
if tool_name not in self.mcp_tool_to_server:
@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",

View file

@ -7,6 +7,7 @@
import asyncio
import json
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.content_types import (
@ -67,7 +69,7 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number)
@ -84,7 +86,16 @@ class ToolExecutor:
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
has_error = bool(
error_exc
or (
result
and (
((error_code := getattr(result, "error_code", None)) and error_code > 0)
or getattr(result, "error_message", None)
)
)
)
async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
):
@ -101,7 +112,9 @@ class ToolExecutor:
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
citation_files=(
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
),
)
async def _execute_knowledge_search_via_vector_store(
@ -188,8 +201,9 @@ class ToolExecutor:
citation_files[file_id] = filename
# Cast to proper InterleavedContent type (list invariance)
return ToolInvocationResult(
content=content_items,
content=content_items, # type: ignore[arg-type]
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
@ -209,51 +223,60 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
elif function_name == "knowledge_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
output_index=output_index,
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function_name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event
if function_name == "knowledge_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
output_index=output_index,
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
@ -261,7 +284,7 @@ class ToolExecutor:
tool_kwargs: dict,
ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]:
) -> tuple[Exception | None, Any]:
"""Execute the tool and return error exception and result."""
error_exc = None
result = None
@ -284,9 +307,13 @@ class ToolExecutor:
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
response_file_search_tool = (
next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if ctx.response_tools
else None
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
@ -322,35 +349,34 @@ class ToolExecutor:
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
elif function_name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
elif function_name == "knowledge_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
async def _build_result_messages(
self,
@ -360,21 +386,18 @@ class ToolExecutor:
tool_kwargs: dict,
ctx: ChatCompletionContext,
error_exc: Exception | None,
result: any,
result: Any,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]:
) -> tuple[Any, Any]:
"""Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
# Build output message
message: Any
if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=item_id,
arguments=function.arguments,
@ -383,10 +406,14 @@ class ToolExecutor:
)
if error_exc:
message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
result and getattr(result, "error_message", None)
):
ec = getattr(result, "error_code", "unknown")
em = getattr(result, "error_message", "")
message.error = f"Error (code {ec}): {em}"
elif result and (content := getattr(result, "content", None)):
message.output = interleaved_content_as_str(content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
@ -401,17 +428,17 @@ class ToolExecutor:
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if result and "document_ids" in result.metadata:
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
for i, doc_id in enumerate(metadata["document_ids"]):
text = metadata["chunks"][i] if "chunks" in metadata else None
score = metadata["scores"][i] if "scores" in metadata else None
message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id,
filename=doc_id,
text=text,
score=score,
text=text if text is not None else "",
score=score if score is not None else 0.0,
attributes={},
)
)
@ -421,27 +448,32 @@ class ToolExecutor:
raise ValueError(f"Unknown tool {function.name} called")
# Build input message
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
content = []
for item in result.content:
input_message: OpenAIToolMessageParam | None = None
if result and (result_content := getattr(result, "content", None)):
# The mypy workarounds here are still unsatisfactory given how much of this is effectively Any-typed
if isinstance(result_content, str):
msg_content: str | list[Any] = result_content
elif isinstance(result_content, list):
content_list: list[Any] = []
for item in result_content:
part: Any
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
url_value = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
url_value = str(item.image.url) if item.image.url else ""
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
content_list.append(part)
msg_content = content_list
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
raise ValueError(f"Unknown result content type: {type(result_content)}")
# OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
# This is runtime-safe as the API accepts it, but mypy complains
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
else:
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import cast
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
@ -100,17 +101,19 @@ class ToolContext(BaseModel):
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
tools_to_process: list[OpenAIResponseInputTool] = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
@ -119,9 +122,11 @@ class ToolContext(BaseModel):
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
for mcp_tool in listing.tools:
# mcp_tool is MCPListToolsTool which has a name: str field
self.previous_tools[mcp_tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
@ -139,6 +144,8 @@ class ToolContext(BaseModel):
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
# Exhaustive check - all tool types should be handled above
raise AssertionError(f"Unexpected tool type: {type(tool)}")
return [convert_tool(tool) for tool in self.current_tools]

View file

@ -7,6 +7,7 @@
import asyncio
import re
import uuid
from collections.abc import Sequence
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
return OpenAIResponseMessage(
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
if isinstance(content, str):
return content
converted_parts = []
# Type with union to avoid list invariance issues
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
# Output can be None, use empty string as fallback
output_content = input_item.output if input_item.output is not None else ""
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
content=output_content,
tool_call_id=input_item.id,
)
)
@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
elif isinstance(input_item, OpenAIResponseMessage):
# Narrow type to OpenAIResponseMessage which has content and role attributes
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
# Dynamic message type call - different message types have different content expectations
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
if len(tool_call_results):
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
# Assert name exists for json_schema format
assert text.format.get("name"), "json_schema format requires a name"
schema_name: str = text.format["name"] # type: ignore[assignment]
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
def _extract_citations_from_text(
@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()
# TODO: list_shields not in Safety interface but available at runtime via API routing
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
for result in response.results:
if result.flagged:
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
flagged_categories = (
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
)
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories:
@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
return message
# No violations found
return None
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""

View file

@ -6,7 +6,7 @@
import asyncio
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(

View file

@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
api_key = self._get_api_key_from_config_or_provider_data()
return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()]

View file

@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data()
return [
endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
host=self.config.url, token=api_token
).serving_endpoints.list() # TODO: this is not async
]

View file

@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
print(f"VLM Response: {vlm_response.choices[0].message.content}")
```
### Rerank Example
The following example shows how to rerank documents using an NVIDIA NIM.
```python
rerank_response = client.alpha.inference.rerank(
model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
query="query",
items=[
"item_1",
"item_2",
"item_3",
],
)
for i, result in enumerate(rerank_response):
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
```
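Each result exposes the index of the original item and its relevance score, as shown in the loop above, so scores can be mapped back to the input list. The rerank API also accepts an optional `max_num_results` parameter to cap how many ranked items are returned.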

View file

@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
Attributes:
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
rerank_model_to_url: dict[str, str] = Field(
default_factory=lambda: {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
},
description="Mapping of rerank model identifiers to their API endpoints. ",
)
@classmethod
def sample_run_config(

View file

@ -5,6 +5,19 @@
# the root directory of this source tree.
from collections.abc import Iterable
import aiohttp
from llama_stack.apis.inference import (
RerankData,
RerankResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def list_provider_model_ids(self) -> Iterable[str]:
"""
Return both dynamic model IDs and statically configured rerank model IDs.
"""
dynamic_ids: Iterable[str] = []
try:
dynamic_ids = await super().list_provider_model_ids()
except Exception:
# If the dynamic listing fails, proceed with just configured rerank IDs
dynamic_ids = []
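# Rerank models may not appear in the dynamic listing, so merge in the statically configured rerank IDs as well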
configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates
def construct_model_from_identifier(self, identifier: str) -> Model:
"""
Classify rerank models from config; otherwise use the base behavior.
"""
if identifier in self.config.rerank_model_to_url:
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.rerank,
)
return super().construct_model_from_identifier(identifier)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
provider_model_id = await self._get_provider_model_id(model)
ranking_url = self.get_base_url()
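# Hosted rerank models are assumed to be served from model-specific endpoints, so swap in the configured URL when one is mapped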
if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
ranking_url = self.config.rerank_model_to_url[provider_model_id]
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
# Convert query to text format
if isinstance(query, str):
query_text = query
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
query_text = query.text
else:
raise ValueError("Query must be a string or text content part")
# Convert items to text format
passages = []
for item in items:
if isinstance(item, str):
passages.append({"text": item})
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
passages.append({"text": item.text})
else:
raise ValueError("Items must be strings or text content parts")
payload = {
"model": provider_model_id,
"query": {"text": query_text},
"passages": passages,
}
headers = {
"Authorization": f"Bearer {self.get_api_key()}",
"Content-Type": "application/json",
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(ranking_url, headers=headers, json=payload) as response:
if response.status != 200:
response_text = await response.text()
raise ConnectionError(
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
)
result = await response.json()
rankings = result.get("rankings", [])
# Convert to RerankData format
rerank_data = []
for ranking in rankings:
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
# Apply max_num_results limit
if max_num_results is not None:
rerank_data = rerank_data[:max_num_results]
return RerankResponse(data=rerank_data)
except aiohttp.ClientError as e:
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
import google.auth.transport.requests
from google.auth import default
@ -42,3 +43,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def list_provider_model_ids(self) -> Iterable[str]:
"""
        VertexAI doesn't currently offer a way to query the list of available models from Google's Model Garden.
        For now we return a hardcoded list of the available models.
:return: An iterable of model IDs
"""
return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]

View file

@ -35,6 +35,7 @@ class InferenceStore:
self.reference = reference
self.sql_store = None
self.policy = policy
self.enable_write_queue = True
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
@ -47,14 +48,13 @@ class InferenceStore:
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite to avoid concurrency issues
backend_name = self.reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
# Disable write queue for SQLite since WAL mode handles concurrency
# Keep it enabled for other backends (like Postgres) for performance
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
await self.sql_store.create_table(
"chat_completions",
{
@ -70,8 +70,9 @@ class InferenceStore:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
logger.debug(
f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
if not self._worker_tasks:
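
The gating decision in isolation, with the backend type reduced to a plain string; in the real store it comes from the _SQLSTORE_BACKENDS registry and StorageBackendType, so the string values below are placeholders.

# Illustrative only: the write-queue toggle as a standalone helper.
def should_enable_write_queue(backend_type: str) -> bool:
    # SQLite relies on WAL mode for concurrency, so the async queue is skipped;
    # other backends (e.g. Postgres) keep the queue for write batching.
    return backend_type != "sql_sqlite"


assert should_enable_write_queue("sql_postgres") is True
assert should_enable_write_queue("sql_sqlite") is False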

View file

@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
from typing import Any
input_dict: dict[str, Any] = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Convert to dict for manipulation
fmt_dict = dict(fmt.json_schema)
name = fmt_dict["title"]
del fmt_dict["title"]
fmt_dict["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"schema": fmt_dict,
"strict": self.json_schema_strict,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
return {
"model": request.model,
@ -176,10 +175,10 @@ class LiteLLMOpenAIMixin(
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
api_key = self.api_key_from_config
if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
return str(api_key) # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection
api_key = self.api_key_from_config
if not api_key:
raise ValueError(
"API key is not set. Please provide a valid API key in the "
@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
if not self.model_store:
raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
# Convert input to list if it's a string
input_list = [params.input] if isinstance(params.input, str) else params.input
@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin(
# Call litellm embedding function
# litellm.drop_params = True
response = litellm.embedding(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
model=self.get_litellm_model_name(provider_resource_id),
input=input_list,
api_key=self.get_api_key(),
api_base=self.api_base,
@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin(
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
model=provider_resource_id,
usage=usage,
)
@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
if not self.model_store:
raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
model=self.get_litellm_model_name(provider_resource_id),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin(
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.atext_completion(**request_params)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
async def openai_chat_completion(
self,
@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin(
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
if not self.model_store:
raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
model=self.get_litellm_model_name(provider_resource_id),
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin(
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.acompletion(**request_params)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
async def check_model_availability(self, model: str) -> bool:
"""

View file

@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
allowed_models: list[str] | None = Field(
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)

View file

@ -161,8 +161,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
if isinstance(params.strategy, GreedySamplingStrategy):
options["temperature"] = 0.0
elif isinstance(params.strategy, TopPSamplingStrategy):
options["temperature"] = params.strategy.temperature
options["top_p"] = params.strategy.top_p
if params.strategy.temperature is not None:
options["temperature"] = params.strategy.temperature
if params.strategy.top_p is not None:
options["top_p"] = params.strategy.top_p
elif isinstance(params.strategy, TopKSamplingStrategy):
options["top_k"] = params.strategy.top_k
else:
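
A minimal stand-in showing the effect of the new None-guards, using plain arguments instead of TopPSamplingStrategy.

# Illustrative only: None values no longer leak into the provider options.
def top_p_options(temperature: float | None, top_p: float | None) -> dict:
    options: dict = {}
    if temperature is not None:
        options["temperature"] = temperature
    if top_p is not None:
        options["top_p"] = top_p
    return options


assert top_p_options(0.7, 0.9) == {"temperature": 0.7, "top_p": 0.9}
assert top_p_options(None, None) == {}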
@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict:
def text_from_choice(choice) -> str:
if hasattr(choice, "delta") and choice.delta:
return choice.delta.content
return choice.delta.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
if hasattr(choice, "message"):
return choice.message.content
return choice.message.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
return choice.text
return choice.text # type: ignore[no-any-return] # external OpenAI types lack precise annotations
def get_stop_reason(finish_reason: str) -> StopReason:
@ -216,7 +218,7 @@ def convert_openai_completion_logprobs(
) -> list[TokenLogProbs] | None:
if not logprobs:
return None
if hasattr(logprobs, "top_logprobs"):
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
# Together supports logprobs with top_k=1 only. This means for each token position,
@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
if isinstance(logprobs, float):
# Adapt response from Together CompletionChoicesChunk
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
if hasattr(logprobs, "top_logprobs"):
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
return None
@ -245,23 +247,24 @@ def process_completion_response(
response: OpenAICompatCompletionResponse,
) -> CompletionResponse:
choice = response.choices[0]
text = choice.text or ""
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
if text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
content=text[: -len("<|eot_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
if text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
content=text[: -len("<|eom_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
stop_reason=get_stop_reason(choice.finish_reason or "stop"),
content=text,
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
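
The stop-token handling above, condensed into a standalone helper for illustration; the helper name is ours, the token strings are the ones checked in the hunk.

# Illustrative only: strip a trailing stop token and report which one was found.
def strip_stop_suffix(text: str) -> tuple[str, str | None]:
    if text.endswith("<|eot_id|>"):
        return text[: -len("<|eot_id|>")], "end_of_turn"
    if text.endswith("<|eom_id|>"):
        return text[: -len("<|eom_id|>")], "end_of_message"
    return text, None  # caller falls back to get_stop_reason(finish_reason or "stop")


assert strip_stop_suffix("Hello<|eot_id|>") == ("Hello", "end_of_turn")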
@ -272,10 +275,10 @@ def process_chat_completion_response(
) -> ChatCompletionResponse:
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
if not choice.message or not choice.message.tool_calls:
if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
raise ValueError("Tool calls are not present in the response")
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
@ -287,9 +290,11 @@ def process_chat_completion_response(
)
else:
# Otherwise, return tool calls as normal
# Filter to only valid ToolCall objects
valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
tool_calls=valid_tool_calls,
stop_reason=StopReason.end_of_turn,
# Content is not optional
content="",
@ -299,7 +304,7 @@ def process_chat_completion_response(
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop"))
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
@ -324,8 +329,8 @@ def process_chat_completion_response(
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent]
stop_reason=raw_message.stop_reason or StopReason.end_of_turn,
tool_calls=raw_message.tool_calls,
),
logprobs=None,
@ -448,7 +453,7 @@ async def process_chat_completion_stream_response(
)
# parse tool calls and report errors
message = decode_assistant_message(buffer, stop_reason)
message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
@ -463,7 +468,7 @@ async def process_chat_completion_stream_response(
)
)
request_tools = {t.tool_name: t for t in request.tools}
request_tools = {t.tool_name: t for t in (request.tools or [])}
for tool_call in message.tool_calls:
if tool_call.tool_name in request_tools:
yield ChatCompletionResponseStreamChunk(
@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
}
if hasattr(message, "tool_calls") and message.tool_calls:
result["tool_calls"] = []
tool_calls_list = []
for tc in message.tool_calls:
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
result["tool_calls"].append(
tool_calls_list.append(
{
"id": tc.call_id,
"type": "function",
@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
},
}
)
result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected
return result
@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new(
),
)
elif isinstance(content_, list):
return [await impl(item) for item in content_]
return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing
else:
raise ValueError(f"Unsupported content type: {type(content_)}")
@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new(
else:
return [ret]
out: OpenAIChatCompletionMessage = None
out: OpenAIChatCompletionMessage
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new(
),
type="function",
)
for tool in message.tool_calls
for tool in (message.tool_calls or [])
]
params = {}
if tool_calls:
@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new(
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
**params,
**params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=await _convert_message_content(message.content),
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=await _convert_message_content(message.content),
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function["name"] = tool.tool_name.value
function["name"] = tool.tool_name.value # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
else:
function["name"] = tool.tool_name
function["name"] = tool.tool_name # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
if tool.description:
function["description"] = tool.description
function["description"] = tool.description # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
if tool.input_schema:
# Pass through the entire JSON Schema as-is
function["parameters"] = tool.input_schema
function["parameters"] = tool.input_schema # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
# NOTE: OpenAI does not support output_schema, so we drop it here
# It's stored in LlamaStack for validation and other provider usage
@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None
tool_config = ToolConfig()
if tool_choice:
try:
tool_choice = ToolChoice(tool_choice)
tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception
except ValueError:
pass
tool_config.tool_choice = tool_choice
tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type
return tool_config
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
lls_tools = []
lls_tools: list[ToolDefinition] = []
if not tools:
return lls_tools
@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
def _convert_openai_request_response_format(
response_format: OpenAIResponseFormatParam = None,
response_format: OpenAIResponseFormatParam | None = None,
):
if not response_format:
return None
# response_format can be a dict or a pydantic model
response_format = dict(response_format)
if response_format.get("type", "") == "json_schema":
response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
if response_format_dict.get("type", "") == "json_schema":
return JsonSchemaResponseFormat(
type="json_schema",
json_schema=response_format.get("json_schema", {}).get("schema", ""),
type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
)
return None
@ -938,16 +944,15 @@ def _convert_openai_sampling_params(
# Map an explicit temperature of 0 to greedy sampling
if temperature == 0:
strategy = GreedySamplingStrategy()
sampling_params.strategy = GreedySamplingStrategy()
else:
# OpenAI defaults to 1.0 for temperature and top_p if unset
if temperature is None:
temperature = 1.0
if top_p is None:
top_p = 1.0
strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) # type: ignore[assignment] # SamplingParams.strategy union accepts this type
sampling_params.strategy = strategy
return sampling_params
@ -957,23 +962,24 @@ def openai_messages_to_messages(
"""
Convert a list of OpenAIChatCompletionMessage into a list of Message.
"""
converted_messages = []
converted_messages: list[Message] = []
for message in messages:
converted_message: Message
if message.role == "system":
converted_message = SystemMessage(content=openai_content_to_content(message.content))
converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
elif message.role == "user":
converted_message = UserMessage(content=openai_content_to_content(message.content))
converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
elif message.role == "assistant":
converted_message = CompletionMessage(
content=openai_content_to_content(message.content),
tool_calls=_convert_openai_tool_calls(message.tool_calls),
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function
stop_reason=StopReason.end_of_turn,
)
elif message.role == "tool":
converted_message = ToolResponseMessage(
role="tool",
call_id=message.tool_call_id,
content=openai_content_to_content(message.content),
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
)
else:
raise ValueError(f"Unknown role {message.role}")
@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten
return [openai_content_to_content(c) for c in content]
elif hasattr(content, "type"):
if content.type == "text":
return TextContentItem(type="text", text=content.text)
return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
elif content.type == "image_url":
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
else:
raise ValueError(f"Unknown content type: {content.type}")
else:
@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice(
completion_message=CompletionMessage(
content=choice.message.content or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union
),
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection
)
@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream(
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason
logprobs = getattr(choice, "logprobs", None)
# if there's a tool call, emit an event for each tool in the list
@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls([tool_call])[0],
tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
else:
@ -1125,12 +1131,15 @@ async def convert_openai_chat_completion_stream(
if tool_call.function.name:
buffer["name"] = tool_call.function.name
delta = f"{buffer['name']}("
buffer["content"] += delta
if buffer["content"] is not None:
buffer["content"] += delta
if tool_call.function.arguments:
delta = tool_call.function.arguments
buffer["arguments"] += delta
buffer["content"] += delta
if buffer["arguments"] is not None and delta:
buffer["arguments"] += delta
if buffer["content"] is not None and delta:
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
elif choice.delta.content:
@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
@ -1155,7 +1164,8 @@ async def convert_openai_chat_completion_stream(
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
if buffer["name"]:
delta = ")"
buffer["content"] += delta
if buffer["content"] is not None:
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream(
)
try:
tool_call = ToolCall(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=buffer["arguments"],
parsed_tool_call = ToolCall(
call_id=buffer["call_id"] or "",
tool_name=buffer["name"] or "",
arguments=buffer["arguments"] or "",
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=buffer["content"],
tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
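
How the per-index buffer is expected to fill up across streamed deltas before the ToolCall above is constructed; the chunk contents below are invented for illustration.

# Illustrative only: accumulate streamed tool-call deltas into the buffer dict.
buffer: dict[str, str | None] = {"call_id": "call_1", "name": None, "arguments": "", "content": ""}

for name, args in [("get_weather", ""), (None, '{"city": '), (None, '"Paris"}')]:
    if name:
        buffer["name"] = name
        buffer["content"] = (buffer["content"] or "") + f"{name}("
    if args:
        buffer["arguments"] = (buffer["arguments"] or "") + args
        buffer["content"] = (buffer["content"] or "") + args

buffer["content"] = (buffer["content"] or "") + ")"
# buffer now feeds ToolCall(call_id=..., tool_name=..., arguments=...)
assert buffer["arguments"] == '{"city": "Paris"}'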
@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
messages = openai_messages_to_messages(messages)
messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
@ -1259,15 +1269,15 @@ class OpenAIChatCompletionToLlamaStackMixin:
)
tool_config = _convert_openai_request_tool_config(tool_choice)
tools = _convert_openai_request_tools(tools)
tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format
if tool_config.tool_choice == ToolChoice.none:
tools = []
tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type
outstanding_responses = []
# "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n):
response = self.chat_completion(
response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion
model_id=model,
messages=messages,
sampling_params=sampling_params,
@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
outstanding_responses.append(response)
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
self, model, outstanding_responses
@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin:
response = await outstanding_response
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
finish_reason = (
_convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None
)
if isinstance(event.delta, TextDelta):
text_delta = event.delta.text
delta = OpenAIChoiceDelta(content=text_delta)
yield OpenAIChatCompletionChunk(
id=id,
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
created=int(time.time()),
model=model,
object="chat.completion.chunk",
@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
elif isinstance(event.delta, ToolCallDelta):
if event.delta.parse_status == ToolCallParseStatus.succeeded:
tool_call = event.delta.tool_call
if isinstance(tool_call, str):
continue
# First chunk includes full structure
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
id=tool_call.call_id,
function=OpenAIChoiceDeltaToolCallFunction(
name=tool_call.tool_name,
name=tool_call.tool_name
if isinstance(tool_call.tool_name, str)
else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy
arguments="",
),
)
@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
],
created=int(time.time()),
model=model,
@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
],
created=int(time.time()),
model=model,
@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
async def _process_non_stream_response(
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
choices = []
choices: list[OpenAIChatCompletionChoice] = []
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin:
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type
finish_reason=finish_reason,
)
choices.append(choice)
choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible
created=int(time.time()),
model=model,
object="chat.completion",

View file

@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
for provider_model_id in provider_models_ids:
if not isinstance(provider_model_id, str):
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
if self.allowed_models and provider_model_id not in self.allowed_models:
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
model = self.construct_model_from_identifier(provider_model_id)
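
The semantics of the new check in isolation: None keeps every provider model, while any explicit list, including an empty one, restricts registration. A minimal sketch with made-up model IDs.

# Illustrative only: mirrors the config.allowed_models check above.
def is_model_allowed(model_id: str, allowed_models: list[str] | None) -> bool:
    return allowed_models is None or model_id in allowed_models


assert is_model_allowed("gpt-4o", None) is True            # no allow-list configured
assert is_model_allowed("gpt-4o", ["gpt-4o-mini"]) is False
assert is_model_allowed("gpt-4o", []) is False             # an empty list now blocks everything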

View file

@ -196,6 +196,7 @@ def make_overlapped_chunks(
chunks.append(
Chunk(
content=chunk,
chunk_id=chunk_id,
metadata=chunk_metadata,
chunk_metadata=backend_chunk_metadata,
)

View file

@ -70,13 +70,13 @@ class ResponsesStore:
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite since WAL mode handles concurrency
# Keep it enabled for other backends (like Postgres) for performance
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if backend_config.type == StorageBackendType.SQL_SQLITE:
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
await self.sql_store.create_table(
"openai_responses",
{
@ -99,8 +99,9 @@ class ResponsesStore:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
logger.debug(
f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
if not self._worker_tasks:

View file

@ -17,6 +17,7 @@ from sqlalchemy import (
String,
Table,
Text,
event,
inspect,
select,
text,
@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
self.metadata = MetaData()
def create_engine(self) -> AsyncEngine:
return create_async_engine(self.config.engine_str, pool_pre_ping=True)
# Configure connection args for better concurrency support
connect_args = {}
if "sqlite" in self.config.engine_str:
# SQLite-specific optimizations for concurrent access
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
connect_args["timeout"] = 5.0
connect_args["check_same_thread"] = False # Allow usage across asyncio tasks
engine = create_async_engine(
self.config.engine_str,
pool_pre_ping=True,
connect_args=connect_args,
)
# Enable WAL mode for SQLite to support concurrent readers and writers
if "sqlite" in self.config.engine_str:
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
cursor = dbapi_conn.cursor()
# Enable Write-Ahead Logging for better concurrency
cursor.execute("PRAGMA journal_mode=WAL")
# Set busy timeout to 5 seconds (retry instead of immediate failure)
# With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
cursor.execute("PRAGMA busy_timeout=5000")
# Use NORMAL synchronous mode for better performance (still safe with WAL)
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.close()
return engine
async def create_table(
self,
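
The same PRAGMAs, applied to a plain sqlite3 connection to show what the connect listener does on every new DBAPI connection; the database file name is a placeholder.

# Illustrative only: stdlib sqlite3 instead of SQLAlchemy, same pragmas.
import sqlite3

conn = sqlite3.connect("example.db", timeout=5.0, check_same_thread=False)
cur = conn.cursor()
cur.execute("PRAGMA journal_mode=WAL")      # concurrent readers plus one writer
cur.execute("PRAGMA busy_timeout=5000")     # retry for up to 5s instead of failing immediately
cur.execute("PRAGMA synchronous=NORMAL")    # safe with WAL, faster than FULL
print(cur.execute("PRAGMA journal_mode").fetchone())  # ('wal',)
cur.close()
conn.close()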

View file

@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
return list_type # type: ignore[no-any-return]
def is_generic_sequence(typ: object) -> bool:
"True if the specified type is a generic Sequence, i.e. `Sequence[T]`."
import collections.abc
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is collections.abc.Sequence
def unwrap_generic_sequence(typ: object) -> type:
"""
Extracts the item type of a Sequence type.
:param typ: The Sequence type `Sequence[T]`.
:returns: The item type `T`.
"""
return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type]
def _unwrap_generic_sequence(typ: object) -> type:
"Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)."
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
return sequence_type # type: ignore[no-any-return]
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
"True if the specified type is a generic set, i.e. `Set[T]`."

View file

@ -18,10 +18,12 @@ from .inspection import (
TypeLike,
is_generic_dict,
is_generic_list,
is_generic_sequence,
is_type_optional,
is_type_union,
unwrap_generic_dict,
unwrap_generic_list,
unwrap_generic_sequence,
unwrap_optional_type,
unwrap_union_types,
)
@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
if metadata is not None:
# type is Annotated[T, ...]
arg = typing.get_args(data_type)[0]
return python_type_to_name(arg)
return python_type_to_name(arg, force=force)
if force:
# generic types
if is_type_optional(data_type, strict=True):
inner_name = python_type_to_name(unwrap_optional_type(data_type))
inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True)
return f"Optional__{inner_name}"
elif is_generic_list(data_type):
item_name = python_type_to_name(unwrap_generic_list(data_type))
item_name = python_type_to_name(unwrap_generic_list(data_type), force=True)
return f"List__{item_name}"
elif is_generic_sequence(data_type):
# Treat Sequence the same as List for schema generation purposes
item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True)
return f"List__{item_name}"
elif is_generic_dict(data_type):
key_type, value_type = unwrap_generic_dict(data_type)
key_name = python_type_to_name(key_type)
value_name = python_type_to_name(value_type)
key_name = python_type_to_name(key_type, force=True)
value_name = python_type_to_name(value_type, force=True)
return f"Dict__{key_name}__{value_name}"
elif is_type_union(data_type):
member_types = unwrap_union_types(data_type)
member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types)
return f"Union__{member_names}"
# named system or user-defined type
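
For orientation, a simplified mimic of how the force=True path composes synthetic names and how the flag now reaches nested generics; the leaf names ("str", "int") assume python_type_to_name's default behavior for plain classes, so treat the outputs as illustrative.

# Simplified mimic (not the real function): compose the synthetic names.
def optional_name(inner: str) -> str:
    return f"Optional__{inner}"


def list_name(item: str) -> str:
    return f"List__{item}"  # used for both list[T] and Sequence[T] after this change


assert list_name("str") == "List__str"
assert optional_name(list_name("int")) == "Optional__List__int"  # force=True now recurses into the inner type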

View file

@ -111,7 +111,7 @@ def get_class_property_docstrings(
def docstring_to_schema(data_type: type) -> Schema:
short_description, long_description = get_class_docstrings(data_type)
schema: Schema = {
"title": python_type_to_name(data_type),
"title": python_type_to_name(data_type, force=True),
}
description = "\n".join(filter(None, [short_description, long_description]))
@ -417,6 +417,10 @@ class JsonSchemaGenerator:
if origin_type is list:
(list_type,) = typing.get_args(typ) # unpack single tuple element
return {"type": "array", "items": self.type_to_schema(list_type)}
elif origin_type is collections.abc.Sequence:
# Treat Sequence the same as list for JSON schema (both are arrays)
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
return {"type": "array", "items": self.type_to_schema(sequence_type)}
elif origin_type is dict:
key_type, value_type = typing.get_args(typ)
if not (key_type is str or key_type is int or is_type_enum(key_type)):
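
As an independent cross-check of the array mapping (not using this generator), pydantic's TypeAdapter produces the same shape for a Sequence type; pydantic is already used elsewhere in this codebase.

# Independent cross-check: Sequence[T] maps to a JSON array schema.
from collections.abc import Sequence

from pydantic import TypeAdapter

assert TypeAdapter(Sequence[int]).json_schema() == {"type": "array", "items": {"type": "integer"}}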

View file

@ -156,7 +156,7 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
if parsed.path not in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
normalized["test_id"] = test_id
normalized_json = json.dumps(normalized, sort_keys=True)
@ -430,7 +430,7 @@ class ResponseStorage:
# For model-list endpoints, include digest in filename to distinguish different model sets
endpoint = request.get("endpoint")
if endpoint in ("/api/tags", "/v1/models"):
if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
digest = _model_identifiers_digest(endpoint, response)
response_file = f"models-{request_hash}-{digest}.json"
@ -554,13 +554,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
Supported endpoints:
- '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
- '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
- '/v1/openai/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
Returns a list of unique identifiers or None if structure doesn't match.
"""
if "models" in response["body"]:
# ollama
items = response["body"]["models"]
else:
# openai
# openai or openai-style endpoints
items = response["body"]
idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
return sorted(set(idents))
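
The identifier extraction at the heart of the digest, exercised against stand-in bodies for the two response shapes named in the docstring; the model names are made up and SimpleNamespace stands in for the SDK model objects.

# Illustrative only: identifier extraction for the Ollama and OpenAI-style shapes.
from types import SimpleNamespace

ollama_body = {"models": [SimpleNamespace(model="llama3.2:3b"), SimpleNamespace(model="qwen2.5:7b")]}
openai_body = [SimpleNamespace(id="gpt-4o"), SimpleNamespace(id="gpt-4o-mini"), SimpleNamespace(id="gpt-4o")]

assert sorted({m.model for m in ollama_body["models"]}) == ["llama3.2:3b", "qwen2.5:7b"]
assert sorted({m.id for m in openai_body}) == ["gpt-4o", "gpt-4o-mini"]  # duplicates collapse before hashing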
@ -581,7 +582,7 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
seen: dict[str, dict[str, Any]] = {}
for rec in records:
body = rec["response"]["body"]
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
for m in body:
key = m.id
seen[key] = m
@ -665,7 +666,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
logger.info(f" Test context: {get_test_context()}")
if mode == APIRecordingMode.LIVE or storage is None:
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
return original_method(self, *args, **kwargs)
else:
return await original_method(self, *args, **kwargs)
@ -699,7 +700,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
@ -739,13 +740,13 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
)
if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
response = original_method(self, *args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
response = [m async for m in response]
request_data = {

View file

@ -51,10 +51,14 @@ async function proxyRequest(request: NextRequest, method: string) {
);
// Create response with same status and headers
const proxyResponse = new NextResponse(responseText, {
status: response.status,
statusText: response.statusText,
});
// Handle 204 No Content responses specially
const proxyResponse =
response.status === 204
? new NextResponse(null, { status: 204 })
: new NextResponse(responseText, {
status: response.status,
statusText: response.statusText,
});
// Copy response headers (except problematic ones)
response.headers.forEach((value, key) => {

View file

@ -0,0 +1,5 @@
import { PromptManagement } from "@/components/prompts";
export default function PromptsPage() {
return <PromptManagement />;
}

View file

@ -8,6 +8,7 @@ import {
MessageCircle,
Settings2,
Compass,
FileText,
} from "lucide-react";
import Link from "next/link";
import { usePathname } from "next/navigation";
@ -50,6 +51,11 @@ const manageItems = [
url: "/logs/vector-stores",
icon: Database,
},
{
title: "Prompts",
url: "/prompts",
icon: FileText,
},
{
title: "Documentation",
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",

View file

@ -0,0 +1,4 @@
export { PromptManagement } from "./prompt-management";
export { PromptList } from "./prompt-list";
export { PromptEditor } from "./prompt-editor";
export * from "./types";

View file

@ -0,0 +1,309 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptEditor } from "./prompt-editor";
import type { Prompt, PromptFormData } from "./types";
describe("PromptEditor", () => {
const mockOnSave = jest.fn();
const mockOnCancel = jest.fn();
const mockOnDelete = jest.fn();
const defaultProps = {
onSave: mockOnSave,
onCancel: mockOnCancel,
onDelete: mockOnDelete,
};
beforeEach(() => {
jest.clearAllMocks();
});
describe("Create Mode", () => {
test("renders create form correctly", () => {
render(<PromptEditor {...defaultProps} />);
expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument();
expect(screen.getByText("Variables")).toBeInTheDocument();
expect(screen.getByText("Preview")).toBeInTheDocument();
expect(screen.getByText("Create Prompt")).toBeInTheDocument();
expect(screen.getByText("Cancel")).toBeInTheDocument();
});
test("shows preview placeholder when no content", () => {
render(<PromptEditor {...defaultProps} />);
expect(
screen.getByText("Enter content to preview the compiled prompt")
).toBeInTheDocument();
});
test("submits form with correct data", () => {
render(<PromptEditor {...defaultProps} />);
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Hello {{name}}, welcome!" },
});
fireEvent.click(screen.getByText("Create Prompt"));
expect(mockOnSave).toHaveBeenCalledWith({
prompt: "Hello {{name}}, welcome!",
variables: [],
});
});
test("prevents submission with empty prompt", () => {
render(<PromptEditor {...defaultProps} />);
fireEvent.click(screen.getByText("Create Prompt"));
expect(mockOnSave).not.toHaveBeenCalled();
});
});
describe("Edit Mode", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}, how is {{weather}}?",
version: 1,
variables: ["name", "weather"],
is_default: true,
};
test("renders edit form with existing data", () => {
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
expect(
screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?")
).toBeInTheDocument();
expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview
expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview
expect(screen.getByText("Update Prompt")).toBeInTheDocument();
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
});
test("submits updated data correctly", () => {
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Updated: Hello {{name}}!" },
});
fireEvent.click(screen.getByText("Update Prompt"));
expect(mockOnSave).toHaveBeenCalledWith({
prompt: "Updated: Hello {{name}}!",
variables: ["name", "weather"],
});
});
});
describe("Variables Management", () => {
test("adds new variable", () => {
render(<PromptEditor {...defaultProps} />);
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
fireEvent.change(variableInput, { target: { value: "testVar" } });
fireEvent.click(screen.getByText("Add"));
expect(screen.getByText("testVar")).toBeInTheDocument();
});
test("prevents adding duplicate variables", () => {
render(<PromptEditor {...defaultProps} />);
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
// Add first variable
fireEvent.change(variableInput, { target: { value: "test" } });
fireEvent.click(screen.getByText("Add"));
// Try to add same variable again
fireEvent.change(variableInput, { target: { value: "test" } });
// Button should be disabled
expect(screen.getByText("Add")).toBeDisabled();
});
test("removes variable", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}",
version: 1,
variables: ["name", "location"],
is_default: true,
};
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
// Check that both variables are present initially
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
expect(screen.getAllByText("location").length).toBeGreaterThan(0);
// Remove the location variable by clicking the X button with the specific title
const removeLocationButton = screen.getByTitle(
"Remove location variable"
);
fireEvent.click(removeLocationButton);
// Name should still be there, location should be gone from the variables section
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
expect(
screen.queryByTitle("Remove location variable")
).not.toBeInTheDocument();
});
test("adds variable on Enter key", () => {
render(<PromptEditor {...defaultProps} />);
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
fireEvent.change(variableInput, { target: { value: "enterVar" } });
// Simulate Enter key press
fireEvent.keyPress(variableInput, {
key: "Enter",
code: "Enter",
charCode: 13,
preventDefault: jest.fn(),
});
// Check if the variable was added by looking for the badge
expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0);
});
});
describe("Preview Functionality", () => {
test("shows live preview with variables", () => {
render(<PromptEditor {...defaultProps} />);
// Add prompt content
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Hello {{name}}, welcome to {{place}}!" },
});
// Add variables
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
fireEvent.change(variableInput, { target: { value: "name" } });
fireEvent.click(screen.getByText("Add"));
fireEvent.change(variableInput, { target: { value: "place" } });
fireEvent.click(screen.getByText("Add"));
// Check that preview area shows the content
expect(screen.getByText("Compiled Prompt")).toBeInTheDocument();
});
test("shows variable value inputs in preview", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}",
version: 1,
variables: ["name"],
is_default: true,
};
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
expect(screen.getByText("Variable Values")).toBeInTheDocument();
expect(
screen.getByPlaceholderText("Enter value for name")
).toBeInTheDocument();
});
test("shows color legend for variable states", () => {
render(<PromptEditor {...defaultProps} />);
// Add content to show preview
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Hello {{name}}" },
});
expect(screen.getByText("Used")).toBeInTheDocument();
expect(screen.getByText("Unused")).toBeInTheDocument();
expect(screen.getByText("Undefined")).toBeInTheDocument();
});
});
describe("Error Handling", () => {
test("displays error message", () => {
const errorMessage = "Prompt contains undeclared variables";
render(<PromptEditor {...defaultProps} error={errorMessage} />);
expect(screen.getByText(errorMessage)).toBeInTheDocument();
});
});
describe("Delete Functionality", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}",
version: 1,
variables: ["name"],
is_default: true,
};
test("shows delete button in edit mode", () => {
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
});
test("hides delete button in create mode", () => {
render(<PromptEditor {...defaultProps} />);
expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument();
});
test("calls onDelete with confirmation", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => true);
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
fireEvent.click(screen.getByText("Delete Prompt"));
expect(window.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this prompt? This action cannot be undone."
);
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
window.confirm = originalConfirm;
});
test("does not delete when confirmation is cancelled", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => false);
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
fireEvent.click(screen.getByText("Delete Prompt"));
expect(mockOnDelete).not.toHaveBeenCalled();
window.confirm = originalConfirm;
});
});
describe("Cancel Functionality", () => {
test("calls onCancel when cancel button is clicked", () => {
render(<PromptEditor {...defaultProps} />);
fireEvent.click(screen.getByText("Cancel"));
expect(mockOnCancel).toHaveBeenCalled();
});
});
});

View file

@ -0,0 +1,346 @@
"use client";
import { useState, useEffect } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
import { Badge } from "@/components/ui/badge";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Separator } from "@/components/ui/separator";
import { X, Plus, Save, Trash2 } from "lucide-react";
import { Prompt, PromptFormData } from "./types";
interface PromptEditorProps {
prompt?: Prompt;
onSave: (prompt: PromptFormData) => void;
onCancel: () => void;
onDelete?: (promptId: string) => void;
error?: string | null;
}
export function PromptEditor({
prompt,
onSave,
onCancel,
onDelete,
error,
}: PromptEditorProps) {
const [formData, setFormData] = useState<PromptFormData>({
prompt: "",
variables: [],
});
const [newVariable, setNewVariable] = useState("");
const [variableValues, setVariableValues] = useState<Record<string, string>>(
{}
);
useEffect(() => {
if (prompt) {
setFormData({
prompt: prompt.prompt || "",
variables: prompt.variables || [],
});
}
}, [prompt]);
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
if (!formData.prompt.trim()) {
return;
}
onSave(formData);
};
const addVariable = () => {
if (
newVariable.trim() &&
!formData.variables.includes(newVariable.trim())
) {
setFormData(prev => ({
...prev,
variables: [...prev.variables, newVariable.trim()],
}));
setNewVariable("");
}
};
const removeVariable = (variableToRemove: string) => {
setFormData(prev => ({
...prev,
variables: prev.variables.filter(
variable => variable !== variableToRemove
),
}));
};
const renderPreview = () => {
const text = formData.prompt;
if (!text) return text;
// Split text by variable patterns and process each part
const parts = text.split(/(\{\{\s*\w+\s*\}\})/g);
return parts.map((part, index) => {
const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/);
if (variableMatch) {
const variableName = variableMatch[1];
const isDefined = formData.variables.includes(variableName);
const value = variableValues[variableName];
if (!isDefined) {
// Variable not in variables list - likely a typo/bug (RED)
return (
<span
key={index}
className="bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200 px-1 rounded font-medium"
>
{part}
</span>
);
} else if (value && value.trim()) {
// Variable defined and has value - show the value (GREEN)
return (
<span
key={index}
className="bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200 px-1 rounded font-medium"
>
{value}
</span>
);
} else {
// Variable defined but empty (YELLOW)
return (
<span
key={index}
className="bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200 px-1 rounded font-medium"
>
{part}
</span>
);
}
}
return part;
});
};
const updateVariableValue = (variable: string, value: string) => {
setVariableValues(prev => ({
...prev,
[variable]: value,
}));
};
return (
<form onSubmit={handleSubmit} className="space-y-6">
{error && (
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-md">
<p className="text-destructive text-sm">{error}</p>
</div>
)}
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{/* Form Section */}
<div className="space-y-4">
<div>
<Label htmlFor="prompt">Prompt Content *</Label>
<Textarea
id="prompt"
value={formData.prompt}
onChange={e =>
setFormData(prev => ({ ...prev, prompt: e.target.value }))
}
placeholder="Enter your prompt content here. Use {{variable_name}} for dynamic variables."
className="min-h-32 font-mono mt-2"
required
/>
<p className="text-xs text-muted-foreground mt-2">
Use double curly braces around variable names, e.g.,{" "}
{`{{user_name}}`} or {`{{topic}}`}
</p>
</div>
<div className="space-y-3">
<Label className="text-sm font-medium">Variables</Label>
<div className="flex gap-2 mt-2">
<Input
value={newVariable}
onChange={e => setNewVariable(e.target.value)}
placeholder="Add variable name (e.g. user_name, topic)"
onKeyPress={e =>
e.key === "Enter" && (e.preventDefault(), addVariable())
}
className="flex-1"
/>
<Button
type="button"
onClick={addVariable}
size="sm"
disabled={
!newVariable.trim() ||
formData.variables.includes(newVariable.trim())
}
>
<Plus className="h-4 w-4" />
Add
</Button>
</div>
{formData.variables.length > 0 && (
<div className="border rounded-lg p-3 bg-muted/20">
<div className="flex flex-wrap gap-2">
{formData.variables.map(variable => (
<Badge
key={variable}
variant="secondary"
className="text-sm px-2 py-1"
>
{variable}
<button
type="button"
onClick={() => removeVariable(variable)}
className="ml-2 hover:text-destructive transition-colors"
title={`Remove ${variable} variable`}
>
<X className="h-3 w-3" />
</button>
</Badge>
))}
</div>
</div>
)}
<p className="text-xs text-muted-foreground">
Variables that can be used in the prompt template. Each variable
should match a {`{{variable}}`} placeholder in the content above.
</p>
</div>
</div>
{/* Preview Section */}
<div className="space-y-4">
<Card>
<CardHeader>
<CardTitle className="text-lg">Preview</CardTitle>
<CardDescription>
Live preview of compiled prompt and variable substitution.
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
{formData.prompt ? (
<>
{/* Variable Values */}
{formData.variables.length > 0 && (
<div className="space-y-3">
<Label className="text-sm font-medium">
Variable Values
</Label>
<div className="space-y-2">
{formData.variables.map(variable => (
<div
key={variable}
className="grid grid-cols-2 gap-3 items-center"
>
<div className="text-sm font-mono text-muted-foreground">
{variable}
</div>
<Input
id={`var-${variable}`}
value={variableValues[variable] || ""}
onChange={e =>
updateVariableValue(variable, e.target.value)
}
placeholder={`Enter value for ${variable}`}
className="text-sm"
/>
</div>
))}
</div>
<Separator />
</div>
)}
{/* Live Preview */}
<div>
<Label className="text-sm font-medium mb-2 block">
Compiled Prompt
</Label>
<div className="bg-muted/50 p-4 rounded-lg border">
<div className="text-sm leading-relaxed whitespace-pre-wrap">
{renderPreview()}
</div>
</div>
<div className="flex flex-wrap gap-4 mt-2 text-xs">
<div className="flex items-center gap-1">
<div className="w-3 h-3 bg-green-500 dark:bg-green-400 border rounded"></div>
<span className="text-muted-foreground">Used</span>
</div>
<div className="flex items-center gap-1">
<div className="w-3 h-3 bg-yellow-500 dark:bg-yellow-400 border rounded"></div>
<span className="text-muted-foreground">Unused</span>
</div>
<div className="flex items-center gap-1">
<div className="w-3 h-3 bg-red-500 dark:bg-red-400 border rounded"></div>
<span className="text-muted-foreground">Undefined</span>
</div>
</div>
</div>
</>
) : (
<div className="text-center py-8">
<div className="text-muted-foreground text-sm">
Enter content to preview the compiled prompt
</div>
<div className="text-xs text-muted-foreground mt-2">
Use {`{{variable_name}}`} to add dynamic variables
</div>
</div>
)}
</CardContent>
</Card>
</div>
</div>
<Separator />
<div className="flex justify-between">
<div>
{prompt && onDelete && (
<Button
type="button"
variant="destructive"
onClick={() => {
if (
confirm(
`Are you sure you want to delete this prompt? This action cannot be undone.`
)
) {
onDelete(prompt.prompt_id);
}
}}
>
<Trash2 className="h-4 w-4 mr-2" />
Delete Prompt
</Button>
)}
</div>
<div className="flex gap-2">
<Button type="button" variant="outline" onClick={onCancel}>
Cancel
</Button>
<Button type="submit">
<Save className="h-4 w-4 mr-2" />
{prompt ? "Update" : "Create"} Prompt
</Button>
</div>
</div>
</form>
);
}
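For orientation, here is a minimal sketch of how a parent might mount this editor in create mode. The host component name and handler bodies are assumptions added for illustration, not part of this commit; the actual caller is prompt-management.tsx further down.

"use client";
// Illustrative host only; see prompt-management.tsx below for the real caller.
import { PromptEditor } from "./prompt-editor";
import type { PromptFormData } from "./types";

export function ExamplePromptEditorHost() {
  const handleSave = (data: PromptFormData) => {
    // A real host would call the prompts API here.
    console.log("save", data.prompt, data.variables);
  };
  return (
    <PromptEditor
      onSave={handleSave}
      onCancel={() => console.log("cancelled")}
      error={null}
    />
  );
}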

View file

@@ -0,0 +1,259 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptList } from "./prompt-list";
import type { Prompt } from "./types";
describe("PromptList", () => {
const mockOnEdit = jest.fn();
const mockOnDelete = jest.fn();
const defaultProps = {
prompts: [],
onEdit: mockOnEdit,
onDelete: mockOnDelete,
};
beforeEach(() => {
jest.clearAllMocks();
});
describe("Empty State", () => {
test("renders empty message when no prompts", () => {
render(<PromptList {...defaultProps} />);
expect(screen.getByText("No prompts yet")).toBeInTheDocument();
});
test("shows filtered empty message when search has no results", () => {
const prompts: Prompt[] = [
{
prompt_id: "prompt_123",
prompt: "Hello world",
version: 1,
variables: [],
is_default: false,
},
];
render(<PromptList {...defaultProps} prompts={prompts} />);
// Search for something that doesn't exist
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "nonexistent" } });
expect(
screen.getByText("No prompts match your filters")
).toBeInTheDocument();
});
});
describe("Prompts Display", () => {
const mockPrompts: Prompt[] = [
{
prompt_id: "prompt_123",
prompt: "Hello {{name}}, how are you?",
version: 1,
variables: ["name"],
is_default: true,
},
{
prompt_id: "prompt_456",
prompt: "Summarize this {{text}} in {{length}} words",
version: 2,
variables: ["text", "length"],
is_default: false,
},
{
prompt_id: "prompt_789",
prompt: "Simple prompt with no variables",
version: 1,
variables: [],
is_default: false,
},
];
test("renders prompts table with correct headers", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
expect(screen.getByText("ID")).toBeInTheDocument();
expect(screen.getByText("Content")).toBeInTheDocument();
expect(screen.getByText("Variables")).toBeInTheDocument();
expect(screen.getByText("Version")).toBeInTheDocument();
expect(screen.getByText("Actions")).toBeInTheDocument();
});
test("renders prompt data correctly", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
// Check prompt IDs
expect(screen.getByText("prompt_123")).toBeInTheDocument();
expect(screen.getByText("prompt_456")).toBeInTheDocument();
expect(screen.getByText("prompt_789")).toBeInTheDocument();
// Check content
expect(
screen.getByText("Hello {{name}}, how are you?")
).toBeInTheDocument();
expect(
screen.getByText("Summarize this {{text}} in {{length}} words")
).toBeInTheDocument();
expect(
screen.getByText("Simple prompt with no variables")
).toBeInTheDocument();
// Check versions
expect(screen.getAllByText("1")).toHaveLength(2); // Two prompts with version 1
expect(screen.getByText("2")).toBeInTheDocument();
// Check default badge
expect(screen.getByText("Default")).toBeInTheDocument();
});
test("renders variables correctly", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
// Check variables display
expect(screen.getByText("name")).toBeInTheDocument();
expect(screen.getByText("text")).toBeInTheDocument();
expect(screen.getByText("length")).toBeInTheDocument();
expect(screen.getByText("None")).toBeInTheDocument(); // For prompt with no variables
});
test("prompt ID links are clickable and call onEdit", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
// Click on the first prompt ID link
const promptLink = screen.getByRole("button", { name: "prompt_123" });
fireEvent.click(promptLink);
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
});
test("edit buttons call onEdit", () => {
const { container } = render(
<PromptList {...defaultProps} prompts={mockPrompts} />
);
// Find the action buttons in the table - they should be in the last column
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const editButton = firstActionCell?.querySelector("button");
expect(editButton).toBeInTheDocument();
fireEvent.click(editButton!);
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
});
test("delete buttons call onDelete with confirmation", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => true);
const { container } = render(
<PromptList {...defaultProps} prompts={mockPrompts} />
);
// Find the delete button (second button in the first action cell)
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const buttons = firstActionCell?.querySelectorAll("button");
const deleteButton = buttons?.[1]; // Second button should be delete
expect(deleteButton).toBeInTheDocument();
fireEvent.click(deleteButton!);
expect(window.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this prompt? This action cannot be undone."
);
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
window.confirm = originalConfirm;
});
test("delete does not execute when confirmation is cancelled", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => false);
const { container } = render(
<PromptList {...defaultProps} prompts={mockPrompts} />
);
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const buttons = firstActionCell?.querySelectorAll("button");
const deleteButton = buttons?.[1]; // Second button should be delete
expect(deleteButton).toBeInTheDocument();
fireEvent.click(deleteButton!);
expect(mockOnDelete).not.toHaveBeenCalled();
window.confirm = originalConfirm;
});
});
describe("Search Functionality", () => {
const mockPrompts: Prompt[] = [
{
prompt_id: "user_greeting",
prompt: "Hello {{name}}, welcome!",
version: 1,
variables: ["name"],
is_default: true,
},
{
prompt_id: "system_summary",
prompt: "Summarize the following text",
version: 1,
variables: [],
is_default: false,
},
];
test("filters prompts by prompt ID", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "user" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
});
test("filters prompts by content", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "welcome" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
});
test("search is case insensitive", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "HELLO" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
});
test("clearing search shows all prompts", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
// Filter first
fireEvent.change(searchInput, { target: { value: "user" } });
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
// Clear search
fireEvent.change(searchInput, { target: { value: "" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.getByText("system_summary")).toBeInTheDocument();
});
});
});

View file

@@ -0,0 +1,164 @@
"use client";
import { useState } from "react";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Input } from "@/components/ui/input";
import { Edit, Search, Trash2 } from "lucide-react";
import { Prompt, PromptFilters } from "./types";
interface PromptListProps {
prompts: Prompt[];
onEdit: (prompt: Prompt) => void;
onDelete: (promptId: string) => void;
}
export function PromptList({ prompts, onEdit, onDelete }: PromptListProps) {
const [filters, setFilters] = useState<PromptFilters>({});
const filteredPrompts = prompts.filter(prompt => {
if (
filters.searchTerm &&
!(
prompt.prompt
?.toLowerCase()
.includes(filters.searchTerm.toLowerCase()) ||
prompt.prompt_id
.toLowerCase()
.includes(filters.searchTerm.toLowerCase())
)
) {
return false;
}
return true;
});
return (
<div className="space-y-4">
{/* Filters */}
<div className="flex flex-col sm:flex-row gap-4">
<div className="relative flex-1">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
<Input
placeholder="Search prompts..."
value={filters.searchTerm || ""}
onChange={e =>
setFilters(prev => ({ ...prev, searchTerm: e.target.value }))
}
className="pl-10"
/>
</div>
</div>
{/* Prompts Table */}
<div className="overflow-auto">
<Table>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Content</TableHead>
<TableHead>Variables</TableHead>
<TableHead>Version</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{filteredPrompts.map(prompt => (
<TableRow key={prompt.prompt_id}>
<TableCell className="max-w-48">
<Button
variant="link"
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300 max-w-full justify-start"
onClick={() => onEdit(prompt)}
title={prompt.prompt_id}
>
<div className="truncate">{prompt.prompt_id}</div>
</Button>
</TableCell>
<TableCell className="max-w-64">
<div
className="font-mono text-xs text-muted-foreground truncate"
title={prompt.prompt || "No content"}
>
{prompt.prompt || "No content"}
</div>
</TableCell>
<TableCell>
{prompt.variables.length > 0 ? (
<div className="flex flex-wrap gap-1">
{prompt.variables.map(variable => (
<Badge
key={variable}
variant="outline"
className="text-xs"
>
{variable}
</Badge>
))}
</div>
) : (
<span className="text-muted-foreground text-sm">None</span>
)}
</TableCell>
<TableCell className="text-sm">
{prompt.version}
{prompt.is_default && (
<Badge variant="secondary" className="text-xs ml-2">
Default
</Badge>
)}
</TableCell>
<TableCell>
<div className="flex gap-1">
<Button
size="sm"
variant="outline"
onClick={() => onEdit(prompt)}
className="h-8 w-8 p-0"
>
<Edit className="h-3 w-3" />
</Button>
<Button
size="sm"
variant="outline"
onClick={() => {
if (
confirm(
`Are you sure you want to delete this prompt? This action cannot be undone.`
)
) {
onDelete(prompt.prompt_id);
}
}}
className="h-8 w-8 p-0 text-destructive hover:text-destructive"
>
<Trash2 className="h-3 w-3" />
</Button>
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
{filteredPrompts.length === 0 && (
<div className="text-center py-12">
<div className="text-muted-foreground">
{prompts.length === 0
? "No prompts yet"
: "No prompts match your filters"}
</div>
</div>
)}
</div>
);
}
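As a quick usage illustration (the sample data and host component below are assumptions, not part of this commit), the list is purely presentational: it receives the full prompt set and reports edit/delete intents back to the caller.

"use client";
// Illustrative only; prompt-management.tsx below is the real caller.
import { PromptList } from "./prompt-list";
import type { Prompt } from "./types";

const samplePrompts: Prompt[] = [
  {
    prompt_id: "prompt_123",
    prompt: "Hello {{name}}",
    version: 1,
    variables: ["name"],
    is_default: true,
  },
];

export function ExamplePromptListHost() {
  return (
    <PromptList
      prompts={samplePrompts}
      onEdit={p => console.log("edit", p.prompt_id)}
      onDelete={id => console.log("delete", id)}
    />
  );
}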

View file

@@ -0,0 +1,304 @@
import React from "react";
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptManagement } from "./prompt-management";
import type { Prompt } from "./types";
// Mock the auth client
const mockPromptsClient = {
list: jest.fn(),
create: jest.fn(),
update: jest.fn(),
delete: jest.fn(),
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => ({
prompts: mockPromptsClient,
}),
}));
describe("PromptManagement", () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe("Loading State", () => {
test("renders loading state initially", () => {
mockPromptsClient.list.mockReturnValue(new Promise(() => {})); // Never resolves
render(<PromptManagement />);
expect(screen.getByText("Loading prompts...")).toBeInTheDocument();
expect(screen.getByText("Prompts")).toBeInTheDocument();
});
});
describe("Empty State", () => {
test("renders empty state when no prompts", async () => {
mockPromptsClient.list.mockResolvedValue([]);
render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("No prompts found.")).toBeInTheDocument();
});
expect(screen.getByText("Create Your First Prompt")).toBeInTheDocument();
});
test("opens modal when clicking 'Create Your First Prompt'", async () => {
mockPromptsClient.list.mockResolvedValue([]);
render(<PromptManagement />);
await waitFor(() => {
expect(
screen.getByText("Create Your First Prompt")
).toBeInTheDocument();
});
fireEvent.click(screen.getByText("Create Your First Prompt"));
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
});
});
describe("Error State", () => {
test("renders error state when API fails", async () => {
const error = new Error("API not found");
mockPromptsClient.list.mockRejectedValue(error);
render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText(/Error:/)).toBeInTheDocument();
});
});
test("renders specific error for 404", async () => {
const error = new Error("404 Not found");
mockPromptsClient.list.mockRejectedValue(error);
render(<PromptManagement />);
await waitFor(() => {
expect(
screen.getByText(/Prompts API endpoint not found/)
).toBeInTheDocument();
});
});
});
describe("Prompts List", () => {
const mockPrompts: Prompt[] = [
{
prompt_id: "prompt_123",
prompt: "Hello {{name}}, how are you?",
version: 1,
variables: ["name"],
is_default: true,
},
{
prompt_id: "prompt_456",
prompt: "Summarize this {{text}}",
version: 2,
variables: ["text"],
is_default: false,
},
];
test("renders prompts list correctly", async () => {
mockPromptsClient.list.mockResolvedValue(mockPrompts);
render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("prompt_123")).toBeInTheDocument();
});
expect(screen.getByText("prompt_456")).toBeInTheDocument();
expect(
screen.getByText("Hello {{name}}, how are you?")
).toBeInTheDocument();
expect(screen.getByText("Summarize this {{text}}")).toBeInTheDocument();
});
test("opens modal when clicking 'New Prompt' button", async () => {
mockPromptsClient.list.mockResolvedValue(mockPrompts);
render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("prompt_123")).toBeInTheDocument();
});
fireEvent.click(screen.getByText("New Prompt"));
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
});
});
describe("Modal Operations", () => {
const mockPrompts: Prompt[] = [
{
prompt_id: "prompt_123",
prompt: "Hello {{name}}",
version: 1,
variables: ["name"],
is_default: true,
},
];
test("closes modal when clicking cancel", async () => {
mockPromptsClient.list.mockResolvedValue(mockPrompts);
render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("prompt_123")).toBeInTheDocument();
});
// Open modal
fireEvent.click(screen.getByText("New Prompt"));
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
// Close modal
fireEvent.click(screen.getByText("Cancel"));
expect(screen.queryByText("Create New Prompt")).not.toBeInTheDocument();
});
test("creates new prompt successfully", async () => {
const newPrompt: Prompt = {
prompt_id: "prompt_new",
prompt: "New prompt content",
version: 1,
variables: [],
is_default: false,
};
mockPromptsClient.list.mockResolvedValue(mockPrompts);
mockPromptsClient.create.mockResolvedValue(newPrompt);
render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("prompt_123")).toBeInTheDocument();
});
// Open modal
fireEvent.click(screen.getByText("New Prompt"));
// Fill form
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "New prompt content" },
});
// Submit form
fireEvent.click(screen.getByText("Create Prompt"));
await waitFor(() => {
expect(mockPromptsClient.create).toHaveBeenCalledWith({
prompt: "New prompt content",
variables: [],
});
});
});
test("handles create error gracefully", async () => {
const error = {
detail: {
errors: [{ msg: "Prompt contains undeclared variables: ['test']" }],
},
};
mockPromptsClient.list.mockResolvedValue(mockPrompts);
mockPromptsClient.create.mockRejectedValue(error);
render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("prompt_123")).toBeInTheDocument();
});
// Open modal
fireEvent.click(screen.getByText("New Prompt"));
// Fill form
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, { target: { value: "Hello {{test}}" } });
// Submit form
fireEvent.click(screen.getByText("Create Prompt"));
await waitFor(() => {
expect(
screen.getByText("Prompt contains undeclared variables: ['test']")
).toBeInTheDocument();
});
});
test("updates existing prompt successfully", async () => {
const updatedPrompt: Prompt = {
...mockPrompts[0],
prompt: "Updated content",
};
mockPromptsClient.list.mockResolvedValue(mockPrompts);
mockPromptsClient.update.mockResolvedValue(updatedPrompt);
const { container } = render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("prompt_123")).toBeInTheDocument();
});
// Click edit button (first button in the action cell of the first row)
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const editButton = firstActionCell?.querySelector("button");
expect(editButton).toBeInTheDocument();
fireEvent.click(editButton!);
expect(screen.getByText("Edit Prompt")).toBeInTheDocument();
// Update content
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, { target: { value: "Updated content" } });
// Submit form
fireEvent.click(screen.getByText("Update Prompt"));
await waitFor(() => {
expect(mockPromptsClient.update).toHaveBeenCalledWith("prompt_123", {
prompt: "Updated content",
variables: ["name"],
version: 1,
set_as_default: true,
});
});
});
test("deletes prompt successfully", async () => {
mockPromptsClient.list.mockResolvedValue(mockPrompts);
mockPromptsClient.delete.mockResolvedValue(undefined);
// Mock window.confirm
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => true);
const { container } = render(<PromptManagement />);
await waitFor(() => {
expect(screen.getByText("prompt_123")).toBeInTheDocument();
});
// Click delete button (second button in the action cell of the first row)
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const buttons = firstActionCell?.querySelectorAll("button");
const deleteButton = buttons?.[1]; // Second button should be delete
expect(deleteButton).toBeInTheDocument();
fireEvent.click(deleteButton!);
await waitFor(() => {
expect(mockPromptsClient.delete).toHaveBeenCalledWith("prompt_123");
});
// Restore window.confirm
window.confirm = originalConfirm;
});
});
});

View file

@@ -0,0 +1,233 @@
"use client";
import { useState, useEffect } from "react";
import { Button } from "@/components/ui/button";
import { Plus } from "lucide-react";
import { PromptList } from "./prompt-list";
import { PromptEditor } from "./prompt-editor";
import { Prompt, PromptFormData } from "./types";
import { useAuthClient } from "@/hooks/use-auth-client";
export function PromptManagement() {
const [prompts, setPrompts] = useState<Prompt[]>([]);
const [showPromptModal, setShowPromptModal] = useState(false);
const [editingPrompt, setEditingPrompt] = useState<Prompt | undefined>();
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null); // For main page errors (loading, etc.)
const [modalError, setModalError] = useState<string | null>(null); // For form submission errors
const client = useAuthClient();
// Load prompts from API on component mount
useEffect(() => {
const fetchPrompts = async () => {
try {
setLoading(true);
setError(null);
const response = await client.prompts.list();
setPrompts(response || []);
} catch (err: unknown) {
console.error("Failed to load prompts:", err);
// Handle different types of errors
const error = err as Error & { status?: number };
if (error?.message?.includes("404") || error?.status === 404) {
setError(
"Prompts API endpoint not found. Please ensure your Llama Stack server supports the prompts API."
);
} else if (
error?.message?.includes("not implemented") ||
error?.message?.includes("not supported")
) {
setError(
"Prompts API is not yet implemented on this Llama Stack server."
);
} else {
setError(
`Failed to load prompts: ${error?.message || "Unknown error"}`
);
}
} finally {
setLoading(false);
}
};
fetchPrompts();
}, [client]);
const handleSavePrompt = async (formData: PromptFormData) => {
try {
setModalError(null);
if (editingPrompt) {
// Update existing prompt
const response = await client.prompts.update(editingPrompt.prompt_id, {
prompt: formData.prompt,
variables: formData.variables,
version: editingPrompt.version,
set_as_default: true,
});
// Update local state
setPrompts(prev =>
prev.map(p =>
p.prompt_id === editingPrompt.prompt_id ? response : p
)
);
} else {
// Create new prompt
const response = await client.prompts.create({
prompt: formData.prompt,
variables: formData.variables,
});
// Add to local state
setPrompts(prev => [response, ...prev]);
}
setShowPromptModal(false);
setEditingPrompt(undefined);
} catch (err) {
console.error("Failed to save prompt:", err);
// Extract specific error message from API response
const error = err as Error & {
message?: string;
detail?: { errors?: Array<{ msg?: string }> };
};
// Try to parse JSON from error message if it's a string
let parsedError = error;
if (typeof error?.message === "string" && error.message.includes("{")) {
try {
const jsonMatch = error.message.match(/\d+\s+(.+)/);
if (jsonMatch) {
parsedError = JSON.parse(jsonMatch[1]);
}
} catch {
// If parsing fails, use original error
}
}
// Try to get the specific validation error message
const validationError = parsedError?.detail?.errors?.[0]?.msg;
if (validationError) {
// Clean up validation error messages (remove "Value error, " prefix if present)
const cleanMessage = validationError.replace(/^Value error,\s*/i, "");
setModalError(cleanMessage);
} else {
// For other errors, format them nicely with line breaks
const statusMatch = error?.message?.match(/(\d+)\s+(.+)/);
if (statusMatch) {
const statusCode = statusMatch[1];
const response = statusMatch[2];
setModalError(
`Failed to save prompt: Status Code ${statusCode}\n\nResponse: ${response}`
);
} else {
const message = error?.message || error?.detail || "Unknown error";
setModalError(`Failed to save prompt: ${message}`);
}
}
}
};
const handleEditPrompt = (prompt: Prompt) => {
setEditingPrompt(prompt);
setShowPromptModal(true);
setModalError(null); // Clear any previous modal errors
};
const handleDeletePrompt = async (promptId: string) => {
try {
setError(null);
await client.prompts.delete(promptId);
setPrompts(prev => prev.filter(p => p.prompt_id !== promptId));
// If we're deleting the currently editing prompt, close the modal
if (editingPrompt && editingPrompt.prompt_id === promptId) {
setShowPromptModal(false);
setEditingPrompt(undefined);
}
} catch (err) {
console.error("Failed to delete prompt:", err);
setError("Failed to delete prompt");
}
};
const handleCreateNew = () => {
setEditingPrompt(undefined);
setShowPromptModal(true);
setModalError(null); // Clear any previous modal errors
};
const handleCancel = () => {
setShowPromptModal(false);
setEditingPrompt(undefined);
};
const renderContent = () => {
if (loading) {
return <div className="text-muted-foreground">Loading prompts...</div>;
}
if (error) {
return <div className="text-destructive">Error: {error}</div>;
}
if (!prompts || prompts.length === 0) {
return (
<div className="text-center py-12">
<p className="text-muted-foreground mb-4">No prompts found.</p>
<Button onClick={handleCreateNew}>
<Plus className="h-4 w-4 mr-2" />
Create Your First Prompt
</Button>
</div>
);
}
return (
<PromptList
prompts={prompts}
onEdit={handleEditPrompt}
onDelete={handleDeletePrompt}
/>
);
};
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<h1 className="text-2xl font-semibold">Prompts</h1>
<Button onClick={handleCreateNew} disabled={loading}>
<Plus className="h-4 w-4 mr-2" />
New Prompt
</Button>
</div>
{renderContent()}
{/* Create/Edit Prompt Modal */}
{showPromptModal && (
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
<div className="bg-background border rounded-lg shadow-lg max-w-4xl w-full mx-4 max-h-[90vh] overflow-hidden">
<div className="p-6 border-b">
<h2 className="text-2xl font-bold">
{editingPrompt ? "Edit Prompt" : "Create New Prompt"}
</h2>
</div>
<div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
<PromptEditor
prompt={editingPrompt}
onSave={handleSavePrompt}
onCancel={handleCancel}
onDelete={handleDeletePrompt}
error={modalError}
/>
</div>
</div>
</div>
)}
</div>
);
}
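A minimal sketch of wiring this component into a Next.js app-router page follows. The route path and import alias are assumptions; this commit does not show where the component is actually mounted, and the component handles its own data fetching via useAuthClient.

// Hypothetical page file, e.g. app/prompts/page.tsx; the real route and import path may differ.
import { PromptManagement } from "@/components/prompts/prompt-management";

export default function PromptsPage() {
  return <PromptManagement />;
}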

View file

@@ -0,0 +1,16 @@
export interface Prompt {
prompt_id: string;
prompt: string | null;
version: number;
variables: string[];
is_default: boolean;
}
export interface PromptFormData {
prompt: string;
variables: string[];
}
export interface PromptFilters {
searchTerm?: string;
}
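For reference, a small sketch of how these shapes relate when an existing prompt is opened in the editor; the values are illustrative, and the same mapping is performed in prompt-editor.tsx's useEffect (which uses || rather than ?? for the fallback).

// Illustrative only: a stored Prompt narrowed into the editable PromptFormData shape.
import type { Prompt, PromptFormData } from "./types";

const stored: Prompt = {
  prompt_id: "prompt_123",
  prompt: "Hello {{name}}",
  version: 1,
  variables: ["name"],
  is_default: true,
};

const asFormData: PromptFormData = {
  prompt: stored.prompt ?? "",
  variables: stored.variables,
};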

View file

@@ -0,0 +1,36 @@
import * as React from "react";
import { cva, type VariantProps } from "class-variance-authority";
import { cn } from "@/lib/utils";
const badgeVariants = cva(
"inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
{
variants: {
variant: {
default:
"border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
secondary:
"border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
destructive:
"border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
outline: "text-foreground",
},
},
defaultVariants: {
variant: "default",
},
}
);
export interface BadgeProps
extends React.HTMLAttributes<HTMLDivElement>,
VariantProps<typeof badgeVariants> {}
function Badge({ className, variant, ...props }: BadgeProps) {
return (
<div className={cn(badgeVariants({ variant }), className)} {...props} />
);
}
export { Badge, badgeVariants };

View file

@@ -0,0 +1,24 @@
import * as React from "react";
import * as LabelPrimitive from "@radix-ui/react-label";
import { cva, type VariantProps } from "class-variance-authority";
import { cn } from "@/lib/utils";
const labelVariants = cva(
"text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
);
const Label = React.forwardRef<
React.ElementRef<typeof LabelPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root> &
VariantProps<typeof labelVariants>
>(({ className, ...props }, ref) => (
<LabelPrimitive.Root
ref={ref}
className={cn(labelVariants(), className)}
{...props}
/>
));
Label.displayName = LabelPrimitive.Root.displayName;
export { Label };

View file

@@ -0,0 +1,53 @@
import * as React from "react";
import * as TabsPrimitive from "@radix-ui/react-tabs";
import { cn } from "@/lib/utils";
const Tabs = TabsPrimitive.Root;
const TabsList = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.List>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>
>(({ className, ...props }, ref) => (
<TabsPrimitive.List
ref={ref}
className={cn(
"inline-flex h-10 items-center justify-center rounded-md bg-muted p-1 text-muted-foreground",
className
)}
{...props}
/>
));
TabsList.displayName = TabsPrimitive.List.displayName;
const TabsTrigger = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.Trigger>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Trigger>
>(({ className, ...props }, ref) => (
<TabsPrimitive.Trigger
ref={ref}
className={cn(
"inline-flex items-center justify-center whitespace-nowrap rounded-sm px-3 py-1.5 text-sm font-medium ring-offset-background transition-all focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 data-[state=active]:bg-background data-[state=active]:text-foreground data-[state=active]:shadow-sm",
className
)}
{...props}
/>
));
TabsTrigger.displayName = TabsPrimitive.Trigger.displayName;
const TabsContent = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Content>
>(({ className, ...props }, ref) => (
<TabsPrimitive.Content
ref={ref}
className={cn(
"mt-2 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
className
)}
{...props}
/>
));
TabsContent.displayName = TabsPrimitive.Content.displayName;
export { Tabs, TabsList, TabsTrigger, TabsContent };

View file

@@ -0,0 +1,23 @@
import * as React from "react";
import { cn } from "@/lib/utils";
export type TextareaProps = React.TextareaHTMLAttributes<HTMLTextAreaElement>;
const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
({ className, ...props }, ref) => {
return (
<textarea
className={cn(
"flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
className
)}
ref={ref}
{...props}
/>
);
}
);
Textarea.displayName = "Textarea";
export { Textarea };

View file

@@ -11,14 +11,16 @@
"@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-dropdown-menu": "^2.1.16",
"@radix-ui/react-label": "^2.1.7",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-tooltip": "^1.2.8",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^12.23.24",
"llama-stack-client": "^0.3.0",
"llama-stack-client": "github:llamastack/llama-stack-client-typescript",
"lucide-react": "^0.545.0",
"next": "15.5.4",
"next-auth": "^4.24.11",
@@ -2597,6 +2599,29 @@
}
}
},
"node_modules/@radix-ui/react-label": {
"version": "2.1.7",
"resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.7.tgz",
"integrity": "sha512-YT1GqPSL8kJn20djelMX7/cTRp/Y9w5IZHvfxQTVHrOqa2yMl7i/UfMqKRU5V7mEyKTrUVgJXhNQPVCG8PBLoQ==",
"license": "MIT",
"dependencies": {
"@radix-ui/react-primitive": "2.1.3"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-menu": {
"version": "2.1.16",
"resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.16.tgz",
@@ -2855,6 +2880,36 @@
}
}
},
"node_modules/@radix-ui/react-tabs": {
"version": "1.1.13",
"resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.13.tgz",
"integrity": "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A==",
"license": "MIT",
"dependencies": {
"@radix-ui/primitive": "1.1.3",
"@radix-ui/react-context": "1.1.2",
"@radix-ui/react-direction": "1.1.1",
"@radix-ui/react-id": "1.1.1",
"@radix-ui/react-presence": "1.1.5",
"@radix-ui/react-primitive": "2.1.3",
"@radix-ui/react-roving-focus": "1.1.11",
"@radix-ui/react-use-controllable-state": "1.2.2"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-tooltip": {
"version": "1.2.8",
"resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.8.tgz",
@@ -9629,9 +9684,8 @@
"license": "MIT"
},
"node_modules/llama-stack-client": {
"version": "0.3.0",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.3.0.tgz",
"integrity": "sha512-76K/t1doaGmlBbDxCADaral9Vccvys9P8pqAMIhwBhMAqWudCEORrMMhUSg+pjhamWmEKj3wa++d4zeOGbfN/w==",
"version": "0.4.0-alpha.1",
"resolved": "git+ssh://git@github.com/llamastack/llama-stack-client-typescript.git#78de4862c4b7d77939ac210fa9f9bde77a2c5c5f",
"license": "MIT",
"dependencies": {
"@types/node": "^18.11.18",

View file

@@ -16,14 +16,16 @@
"@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-dropdown-menu": "^2.1.16",
"@radix-ui/react-label": "^2.1.7",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.3",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-tooltip": "^1.2.8",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"framer-motion": "^12.23.24",
"llama-stack-client": "^0.3.0",
"llama-stack-client": "github:llamastack/llama-stack-client-typescript",
"lucide-react": "^0.545.0",
"next": "15.5.4",
"next-auth": "^4.24.11",