Merge remote-tracking branch 'origin/main' into stores

This commit is contained in:
Ashwin Bharambe 2025-10-20 10:49:06 -07:00
commit 490b212576
89 changed files with 19353 additions and 8323 deletions

View file

@ -173,7 +173,9 @@ class ConversationItemDeletedResource(BaseModel):
@runtime_checkable
@trace_protocol
class Conversations(Protocol):
"""Protocol for conversation management operations."""
"""Conversations
Protocol for conversation management operations."""
@webmethod(route="/conversations", method="POST", level=LLAMA_STACK_API_V1)
async def create_conversation(
@ -181,6 +183,8 @@ class Conversations(Protocol):
) -> Conversation:
"""Create a conversation.
Create a conversation.
:param items: Initial items to include in the conversation context.
:param metadata: Set of key-value pairs that can be attached to an object.
:returns: The created conversation object.
@ -189,7 +193,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_conversation(self, conversation_id: str) -> Conversation:
"""Get a conversation with the given ID.
"""Retrieve a conversation.
Get a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The conversation object.
@ -198,7 +204,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}", method="POST", level=LLAMA_STACK_API_V1)
async def update_conversation(self, conversation_id: str, metadata: Metadata) -> Conversation:
"""Update a conversation's metadata with the given ID.
"""Update a conversation.
Update a conversation's metadata with the given ID.
:param conversation_id: The conversation identifier.
:param metadata: Set of key-value pairs that can be attached to an object.
@ -208,7 +216,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_conversation(self, conversation_id: str) -> ConversationDeletedResource:
"""Delete a conversation with the given ID.
"""Delete a conversation.
Delete a conversation with the given ID.
:param conversation_id: The conversation identifier.
:returns: The deleted conversation resource.
@ -217,7 +227,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}/items", method="POST", level=LLAMA_STACK_API_V1)
async def add_items(self, conversation_id: str, items: list[ConversationItem]) -> ConversationItemList:
"""Create items in the conversation.
"""Create items.
Create items in the conversation.
:param conversation_id: The conversation identifier.
:param items: Items to include in the conversation context.
@ -227,7 +239,9 @@ class Conversations(Protocol):
@webmethod(route="/conversations/{conversation_id}/items/{item_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve(self, conversation_id: str, item_id: str) -> ConversationItem:
"""Retrieve a conversation item.
"""Retrieve an item.
Retrieve a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.
@ -244,7 +258,9 @@ class Conversations(Protocol):
limit: int | NotGiven = NOT_GIVEN,
order: Literal["asc", "desc"] | NotGiven = NOT_GIVEN,
) -> ConversationItemList:
"""List items in the conversation.
"""List items.
List items in the conversation.
:param conversation_id: The conversation identifier.
:param after: An item ID to list items after, used in pagination.
@ -259,7 +275,9 @@ class Conversations(Protocol):
async def openai_delete_conversation_item(
self, conversation_id: str, item_id: str
) -> ConversationItemDeletedResource:
"""Delete a conversation item.
"""Delete an item.
Delete a conversation item.
:param conversation_id: The conversation identifier.
:param item_id: The item identifier.

View file

@ -82,7 +82,9 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol):
"""Llama Stack Evaluation API for running evaluations on model and agent candidates."""
"""Evaluations
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)

View file

@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import sys
from pathlib import Path
import yaml
from termcolor import cprint
from llama_stack.cli.stack.utils import ImageType
from llama_stack.core.build import get_provider_dependencies
from llama_stack.core.datatypes import (
BuildConfig,
BuildProvider,
DistributionSpec,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import replace_env_vars
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api
# Directory three levels above this file, joined with "templates".
# NOTE(review): nothing in the visible part of this module reads TEMPLATES_PATH — confirm it is still needed.
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
# Module-level logger, registered under the "cli" category.
logger = get_logger(name=__name__, category="cli")
# These are the dependencies needed by the distribution server.
# `llama-stack` is automatically installed by the installation script.
# run_stack_list_deps_command appends this list to the provider dependencies before printing them.
SERVER_DEPENDENCIES = [
    "aiosqlite",
    "fastapi",
    "fire",
    "httpx",
    "uvicorn",
    "opentelemetry-sdk",
    "opentelemetry-exporter-otlp-proto-http",
]
def format_output_deps_only(
    normal_deps: list[str],
    special_deps: list[str],
    external_deps: list[str],
    uv: bool = False,
) -> str:
    """Render the dependency groups as newline-separated text.

    The first output line holds every normal dependency, space separated and
    individually quoted where needed. Each special dependency string and each
    external dependency string then gets a line of its own.

    :param normal_deps: plain package specs, emitted together on one line.
    :param special_deps: dependency strings that may carry flags; one line each.
    :param external_deps: dependency strings of external providers; one line each.
    :param uv: when true, every line is prefixed with ``uv pip install ``.
    :returns: the formatted lines joined with ``\\n``.
    """
    prefix = "uv pip install " if uv else ""

    # All normal deps share a single line (a single install command when uv is set).
    normal_line = " ".join(quote_if_needed(dep) for dep in normal_deps)
    rendered = [f"{prefix}{normal_line}"]

    # Special deps come first, then external ones, each on its own line.
    rendered.extend(f"{prefix}{quote_special_dep(dep)}" for dep in [*special_deps, *external_deps])
    return "\n".join(rendered)
def _print_cli_error(message: str) -> None:
    """Print ``message`` in red on stderr."""
    cprint(message, color="red", file=sys.stderr)


def run_stack_list_deps_command(args: argparse.Namespace) -> None:
    """Print the pip dependencies for a distribution config or an explicit provider list.

    The build configuration comes from one of two sources:

    * ``args.config`` -- a config file path or distro name, resolved with
      ``resolve_config_or_distro`` in BUILD mode and loaded as a ``BuildConfig``
      (its ``image_type`` is forced to ``"venv"``).
    * ``args.providers`` -- a string of the form ``api1=provider1,api2=provider2``
      that is validated against the provider registry.

    The provider dependencies, ``SERVER_DEPENDENCIES`` and any external API pip
    packages are then formatted with ``format_output_deps_only`` and printed to
    stdout.

    :param args: parsed CLI arguments; reads ``config``, ``providers`` and ``format``.
    :raises SystemExit: with status 1 (after printing an error to stderr) when
        the config cannot be resolved or parsed, when ``--providers`` is
        malformed or names an unknown API/provider, or when neither a config
        nor ``--providers`` was supplied.
    """
    if args.config:
        try:
            from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro

            config_file = resolve_config_or_distro(args.config, Mode.BUILD)
        except ValueError as e:
            _print_cli_error(f"Could not parse config file {args.config}: {e}")
            sys.exit(1)
        if not config_file:
            # Without this guard build_config would be unbound further down.
            _print_cli_error(f"Could not resolve a config file for {args.config}")
            sys.exit(1)
        with open(config_file) as f:
            try:
                contents = yaml.safe_load(f)
                contents = replace_env_vars(contents)
                build_config = BuildConfig(**contents)
                build_config.image_type = "venv"
            except Exception as e:
                _print_cli_error(f"Could not parse config file {config_file}: {e}")
                sys.exit(1)
    elif args.providers:
        provider_list: dict[str, list[BuildProvider]] = {}
        # The registry does not depend on the loop variable; look it up once.
        provider_registry = get_provider_registry()
        for api_provider in args.providers.split(","):
            if "=" not in api_provider:
                _print_cli_error(
                    "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2"
                )
                sys.exit(1)
            # maxsplit=1: an entry with a second "=" must not blow up the tuple
            # unpacking; it is reported below as an invalid provider instead.
            api, provider_type = api_provider.split("=", 1)
            try:
                providers_for_api = provider_registry.get(Api(api), None)
            except ValueError:
                # Api(api) rejects names it does not know; report that the same
                # way as an API that is missing from the registry.
                providers_for_api = None
            if providers_for_api is None:
                _print_cli_error(f"{api} is not a valid API.")
                sys.exit(1)
            if provider_type not in providers_for_api:
                _print_cli_error(f"{provider_type} is not a valid provider for the {api} API.")
                sys.exit(1)
            provider_list.setdefault(api, []).append(
                BuildProvider(
                    provider_type=provider_type,
                    module=None,
                )
            )
        distribution_spec = DistributionSpec(
            providers=provider_list,
            # args.providers is already the comma-separated string; joining it
            # again would put a comma between every character.
            description=args.providers,
        )
        build_config = BuildConfig(image_type=ImageType.VENV.value, distribution_spec=distribution_spec)
    else:
        _print_cli_error("Please specify either a config file / distro name or `--providers`.")
        sys.exit(1)

    normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
    normal_deps += SERVER_DEPENDENCIES

    # Add external API dependencies
    if build_config.external_apis_dir:
        from llama_stack.core.external import load_external_apis

        external_apis = load_external_apis(build_config)
        if external_apis:
            for api_spec in external_apis.values():
                normal_deps.extend(api_spec.pip_packages)

    # Format and output based on requested format
    output = format_output_deps_only(
        normal_deps=normal_deps,
        special_deps=special_deps,
        external_deps=external_provider_dependencies,
        uv=args.format == "uv",
    )
    print(output)
def quote_if_needed(dep):
    """Return ``dep`` wrapped in single quotes if it contains ``,``, ``<``, ``>`` or ``=``.

    Those characters cover comma-separated version sets and the comparison
    operators (<, >, <=, >=, ==, !=), which need escaping in a shell. Any other
    string is returned unchanged.
    """
    shell_sensitive = (",", "<", ">", "=")
    for marker in shell_sensitive:
        if marker in dep:
            return f"'{dep}'"
    return dep
def quote_special_dep(dep_string):
    """
    Quote the package specs inside a special dependency string.

    The string is split on whitespace. Tokens that start with ``-`` are treated
    as flags (e.g. --extra-index-url) and kept verbatim; every other token is
    passed through ``quote_if_needed``. The tokens are re-joined with single
    spaces.
    """
    return " ".join(
        token if token.startswith("-") else quote_if_needed(token)
        for token in dep_string.split()
    )

View file

@ -8,6 +8,9 @@ import textwrap
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger
logger = get_logger(__name__, category="cli")
class StackBuild(Subcommand):
@ -16,7 +19,7 @@ class StackBuild(Subcommand):
self.parser = subparsers.add_parser(
"build",
prog="llama stack build",
description="Build a Llama stack container",
description="[DEPRECATED] Build a Llama stack container. This command is deprecated and will be removed in a future release. Use `llama stack list-deps <distro>' instead.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
@ -93,6 +96,9 @@ the build. If not specified, currently active environment will be used if found.
)
def _run_stack_build_command(self, args: argparse.Namespace) -> None:
logger.warning(
"The 'llama stack build' command is deprecated and will be removed in a future release. Please use 'llama stack list-deps'"
)
# always keep implementation completely silo-ed away from CLI so CLI
# can be fast to load and reduces dependencies
from ._build import run_stack_build_command

View file

@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class StackListDeps(Subcommand):
    """CLI subcommand ``llama stack list-deps``.

    Registers the ``list-deps`` parser with its arguments and forwards the
    parsed arguments to ``run_stack_list_deps_command``.
    """

    def __init__(self, subparsers: argparse._SubParsersAction):
        """Create the ``list-deps`` sub-parser under ``subparsers`` and wire up its handler."""
        super().__init__()
        parser = subparsers.add_parser(
            "list-deps",
            prog="llama stack list-deps",
            description="list the dependencies for a llama stack distribution",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        self.parser = parser
        self._add_arguments()
        parser.set_defaults(func=self._run_stack_list_deps_command)

    def _add_arguments(self):
        """Declare the optional positional ``config`` plus ``--providers`` and ``--format``."""
        add_argument = self.parser.add_argument

        # Optional positional: config file path or distro name.
        add_argument(
            "config",
            type=str,
            nargs="?",  # Make it optional
            metavar="config | distro",
            help="Path to config file to use or name of known distro (llama stack list for a list).",
        )

        add_argument(
            "--providers",
            type=str,
            default=None,
            help="sync dependencies for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
        )

        add_argument(
            "--format",
            type=str,
            choices=["uv", "deps-only"],
            default="deps-only",
            help="Output format: 'uv' shows shell commands, 'deps-only' shows just the list of dependencies without `uv` (default)",
        )

    def _run_stack_list_deps_command(self, args: argparse.Namespace) -> None:
        """Import the implementation lazily and run it with ``args``."""
        # always keep implementation completely silo-ed away from CLI so CLI
        # can be fast to load and reduces dependencies
        from ._list_deps import run_stack_list_deps_command

        return run_stack_list_deps_command(args)

View file

@ -13,6 +13,7 @@ from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
from .list_apis import StackListApis
from .list_deps import StackListDeps
from .list_providers import StackListProviders
from .remove import StackRemove
from .run import StackRun
@ -39,6 +40,7 @@ class StackParser(Subcommand):
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands
StackListDeps.create(subparsers)
StackBuild.create(subparsers)
StackListApis.create(subparsers)
StackListProviders.create(subparsers)

View file

@ -4,7 +4,28 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import sys
from enum import Enum
from functools import lru_cache
from pathlib import Path
import yaml
from termcolor import cprint
from llama_stack.core.datatypes import (
BuildConfig,
Provider,
StackRunConfig,
)
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "distributions"
class ImageType(Enum):
@ -19,3 +40,91 @@ def print_subcommand_description(parser, subparsers):
description = subcommand.description
description_text += f" {name:<21} {description}\n"
parser.epilog = description_text
def generate_run_config(
    build_config: BuildConfig,
    build_dir: Path,
    image_name: str,
) -> Path:
    """
    Generate a run.yaml template file for user to edit from a build.yaml file

    For every API in the build config's distribution spec, each listed provider
    is looked up in the provider registry and turned into a ``Provider`` entry.
    The entry's config comes from the provider config class'
    ``sample_run_config`` when that class can be imported and defines it, and
    is an empty dict otherwise. The result is written to
    ``<build_dir>/<image_name>-run.yaml``.

    :param build_config: the build configuration to derive the run config from.
    :param build_dir: directory the run config file is written to.
    :param image_name: image name; also used in the output file name.
    :returns: path of the written run config file.
    :raises InvalidProviderError: if a selected provider carries a deprecation error.
    """
    providers_by_api = build_config.distribution_spec.providers
    apis = list(providers_by_api.keys())

    container_image = None
    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
        container_image = image_name

    run_config = StackRunConfig(
        container_image=container_image,
        image_name=image_name,
        apis=apis,
        providers={},
        external_providers_dir=build_config.external_providers_dir or EXTERNAL_PROVIDERS_DIR,
    )

    # build providers dict
    provider_registry = get_provider_registry(build_config)
    for api in apis:
        run_config.providers[api] = []

        for provider in providers_by_api[api]:
            provider_id = provider.provider_type.split("::")[-1]

            spec = provider_registry[Api(api)][provider.provider_type]
            if spec.deprecation_error:
                raise InvalidProviderError(spec.deprecation_error)

            try:
                config_type = instantiate_class_type(spec.config_class)
            except (ModuleNotFoundError, ValueError) as exc:
                # HACK ALERT:
                # This code executes after building is done, the import cannot work since the
                # package is either available in the venv or container - not available on the host.
                # TODO: use a "is_external" flag in ProviderSpec to check if the provider is
                # external
                cprint(
                    f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
                    color="yellow",
                    file=sys.stderr,
                )
                # No config class available; fall through to the empty config below.
                config_type = None

            config = {}
            if config_type is not None and hasattr(config_type, "sample_run_config"):
                config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")

            run_config.providers[api].append(
                Provider(
                    provider_id=provider_id,
                    provider_type=provider.provider_type,
                    config=config,
                    module=provider.module,
                )
            )

    run_config_file = build_dir / f"{image_name}-run.yaml"
    # Round-trip through JSON so that only plain JSON-compatible values reach yaml.dump.
    serializable = json.loads(run_config.model_dump_json())
    with open(run_config_file, "w") as f:
        f.write(yaml.dump(serializable, sort_keys=False))

    # Only print this message for non-container builds since it will be displayed before the
    # container is built
    # For non-container builds, the run.yaml is generated at the very end of the build process so it
    # makes sense to display this message
    if build_config.image_type != LlamaStackImageType.CONTAINER.value:
        cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)

    return run_config_file
@lru_cache
def available_templates_specs() -> dict[str, BuildConfig]:
    """Load every ``*build.yaml`` found under ``TEMPLATES_PATH`` (recursively).

    :returns: mapping from the name of the directory containing each build file
        to the ``BuildConfig`` parsed from it. The result is memoized by
        ``lru_cache``.
    """
    import yaml

    specs = {}
    for build_yaml in TEMPLATES_PATH.rglob("*build.yaml"):
        with open(build_yaml) as fh:
            raw = yaml.safe_load(fh)
        # Keyed by the parent directory name of the build file.
        specs[build_yaml.parent.name] = BuildConfig(**raw)
    return specs

View file

@ -338,7 +338,7 @@ fi
# Add other required commands generic to all containers
add_to_container << EOF
RUN mkdir -p /.llama /.cache && chmod -R g+rw /app /.llama /.cache
RUN mkdir -p /.llama /.cache && chmod -R g+rw /.llama /.cache && (chmod -R g+rw /app 2>/dev/null || true)
EOF
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import importlib
import importlib.metadata
import inspect
from typing import Any

View file

@ -42,3 +42,8 @@ def sync_test_context_from_provider_data():
return TEST_CONTEXT.set(provider_data["__test_id"])
return None
def is_debug_mode() -> bool:
    """Report whether test recording debug mode is on.

    Debug mode is enabled when the LLAMA_STACK_TEST_DEBUG environment variable
    is set to "1", "true" or "yes" (case-insensitive); unset or any other value
    means disabled.
    """
    flag = os.environ.get("LLAMA_STACK_TEST_DEBUG", "")
    return flag.lower() in {"1", "true", "yes"}

View file

@ -42,25 +42,25 @@ def resolve_config_or_distro(
# Strategy 1: Try as file path first
config_path = Path(config_or_distro)
if config_path.exists() and config_path.is_file():
logger.info(f"Using file path: {config_path}")
logger.debug(f"Using file path: {config_path}")
return config_path.resolve()
# Strategy 2: Try as distribution name (if no .yaml extension)
if not config_or_distro.endswith(".yaml"):
distro_config = _get_distro_config_path(config_or_distro, mode)
if distro_config.exists():
logger.info(f"Using distribution: {distro_config}")
logger.debug(f"Using distribution: {distro_config}")
return distro_config
# Strategy 3: Try as built distribution name
distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
logger.debug(f"Using built distribution: {distrib_config}")
return distrib_config
distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml"
if distrib_config.exists():
logger.info(f"Using built distribution: {distrib_config}")
logger.debug(f"Using built distribution: {distrib_config}")
return distrib_config
# Strategy 4: Failed - provide helpful error

View file

@ -70,10 +70,10 @@ docker run \
### Via venv
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
Make sure you have the Llama Stack CLI available.
```bash
llama stack build --distro {{ name }} --image-type venv
llama stack list-deps meta-reference-gpu | xargs -L1 uv pip install
INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
llama stack run distributions/{{ name }}/run.yaml \
--port 8321

View file

@ -126,11 +126,11 @@ docker run \
### Via venv
If you've set up your local development environment, you can also build the image using your local virtual environment.
If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment.
```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
llama stack build --distro nvidia --image-type venv
llama stack list-deps nvidia | xargs -L1 uv pip install
NVIDIA_API_KEY=$NVIDIA_API_KEY \
INFERENCE_MODEL=$INFERENCE_MODEL \
llama stack run ./run.yaml \

View file

@ -79,7 +79,6 @@ class TelemetryAdapter(Telemetry):
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
self._lock = _global_lock
async def initialize(self) -> None:

View file

@ -45,7 +45,7 @@ The following example shows how to create a chat completion for an NVIDIA NIM.
```python
response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "system",
@ -67,37 +67,40 @@ print(f"Response: {response.choices[0].message.content}")
The following example shows how to do tool calling for an NVIDIA NIM.
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
tool_definition = {
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather information for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"description": "Temperature unit (celsius or fahrenheit)",
"default": "celsius",
},
},
"required": ["location"],
},
},
)
}
tool_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.choices[0].message.content}")
print(f"Response content: {tool_response.choices[0].message.content}")
if tool_response.choices[0].message.tool_calls:
for tool_call in tool_response.choices[0].message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
print(f"Tool Called: {tool_call.function.name}")
print(f"Arguments: {tool_call.function.arguments}")
```
### Structured Output Example
@ -105,33 +108,26 @@ if tool_response.choices[0].message.tool_calls:
The following example shows how to do structured output for an NVIDIA NIM.
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"age": {"type": "number"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.chat.completions.create(
model="meta-llama/Llama-3.1-8B-Instruct",
model="nvidia/meta/llama-3.1-8b-instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
extra_body={"nvext": {"guided_json": person_schema}},
)
print(f"Structured Response: {structured_response.choices[0].message.content}")
```
@ -141,7 +137,7 @@ The following example shows how to create embeddings for an NVIDIA NIM.
```python
response = client.embeddings.create(
model="nvidia/llama-3.2-nv-embedqa-1b-v2",
model="nvidia/nvidia/llama-3.2-nv-embedqa-1b-v2",
input=["What is the capital of France?"],
extra_body={"input_type": "query"},
)
@ -163,15 +159,15 @@ image_path = {path_to_the_image}
demo_image_b64 = load_image_as_base64(image_path)
vlm_response = client.chat.completions.create(
model="nvidia/vila",
model="nvidia/meta/llama-3.2-11b-vision-instruct",
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"image": {
"data": demo_image_b64,
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{demo_image_b64}",
},
},
{

View file

@ -19,15 +19,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
"""
NVIDIA Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability(). It also
must come before Inference to ensure that OpenAIMixin methods are available
in the Inference interface.
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
# source: https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html

View file

@ -70,7 +70,7 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
"__class__": class_name,
"__method__": method_name,
"__type__": span_type,
"__args__": str(combined_args),
"__args__": json.dumps(combined_args),
}
return class_name, method_name, span_attributes
@ -82,8 +82,8 @@ def trace_protocol[T](cls: type[T]) -> type[T]:
class_name, method_name, span_attributes = create_span_context(self, *args, **kwargs)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
count = 0
try:
count = 0
async for item in method(self, *args, **kwargs):
yield item
count += 1

View file

@ -37,7 +37,7 @@ _id_counters: dict[str, dict[str, int]] = {}
# Test context uses ContextVar since it changes per-test and needs async isolation
from openai.types.completion_choice import CompletionChoice
from llama_stack.core.testing_context import get_test_context
from llama_stack.core.testing_context import get_test_context, is_debug_mode
# update the "finish_reason" field, since its type definition is wrong (no None is accepted)
CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
@ -146,6 +146,7 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
body_for_hash = _normalize_body_for_hash(body)
test_id = get_test_context()
normalized: dict[str, Any] = {
"method": method.upper(),
"endpoint": parsed.path,
@ -154,10 +155,20 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
normalized["test_id"] = get_test_context()
normalized["test_id"] = test_id
normalized_json = json.dumps(normalized, sort_keys=True)
return hashlib.sha256(normalized_json.encode()).hexdigest()
request_hash = hashlib.sha256(normalized_json.encode()).hexdigest()
if is_debug_mode():
logger.info("[RECORDING DEBUG] Hash computation:")
logger.info(f" Test ID: {test_id}")
logger.info(f" Method: {method.upper()}")
logger.info(f" Endpoint: {parsed.path}")
logger.info(f" Model: {body.get('model', 'N/A')}")
logger.info(f" Computed hash: {request_hash}")
return request_hash
def normalize_tool_request(provider_name: str, tool_name: str, kwargs: dict[str, Any]) -> str:
@ -212,6 +223,11 @@ def patch_httpx_for_test_id():
provider_data["__test_id"] = test_id
request.headers["X-LlamaStack-Provider-Data"] = json.dumps(provider_data)
if is_debug_mode():
logger.info("[RECORDING DEBUG] Injected test ID into request header:")
logger.info(f" Test ID: {test_id}")
logger.info(f" URL: {request.url}")
return None
LlamaStackClient._prepare_request = patched_prepare_request
@ -355,12 +371,35 @@ class ResponseStorage:
test_file = test_id.split("::")[0] # Remove test function part
test_dir = Path(test_file).parent # Get parent directory
# Put recordings in a "recordings" subdirectory of the test's parent dir
# e.g., "tests/integration/inference" -> "tests/integration/inference/recordings"
return test_dir / "recordings"
if self.base_dir.is_absolute():
repo_root = self.base_dir.parent.parent.parent
result = repo_root / test_dir / "recordings"
if is_debug_mode():
logger.info("[RECORDING DEBUG] Path resolution (absolute base_dir):")
logger.info(f" Test ID: {test_id}")
logger.info(f" Base dir: {self.base_dir}")
logger.info(f" Repo root: {repo_root}")
logger.info(f" Test file: {test_file}")
logger.info(f" Test dir: {test_dir}")
logger.info(f" Recordings dir: {result}")
return result
else:
result = test_dir / "recordings"
if is_debug_mode():
logger.info("[RECORDING DEBUG] Path resolution (relative base_dir):")
logger.info(f" Test ID: {test_id}")
logger.info(f" Base dir: {self.base_dir}")
logger.info(f" Test dir: {test_dir}")
logger.info(f" Recordings dir: {result}")
return result
else:
# Fallback for non-test contexts
return self.base_dir / "recordings"
result = self.base_dir / "recordings"
if is_debug_mode():
logger.info("[RECORDING DEBUG] Path resolution (no test context):")
logger.info(f" Base dir: {self.base_dir}")
logger.info(f" Recordings dir: {result}")
return result
def _ensure_directory(self):
"""Ensure test-specific directories exist."""
@ -395,6 +434,13 @@ class ResponseStorage:
response_path = responses_dir / response_file
if is_debug_mode():
logger.info("[RECORDING DEBUG] Storing recording:")
logger.info(f" Request hash: {request_hash}")
logger.info(f" File: {response_path}")
logger.info(f" Test ID: {get_test_context()}")
logger.info(f" Endpoint: {endpoint}")
# Save response to JSON file with metadata
with open(response_path, "w") as f:
json.dump(
@ -423,16 +469,33 @@ class ResponseStorage:
test_dir = self._get_test_dir()
response_path = test_dir / response_file
if is_debug_mode():
logger.info("[RECORDING DEBUG] Looking up recording:")
logger.info(f" Request hash: {request_hash}")
logger.info(f" Primary path: {response_path}")
logger.info(f" Primary exists: {response_path.exists()}")
if response_path.exists():
if is_debug_mode():
logger.info(" Found in primary location")
return _recording_from_file(response_path)
# Fallback to base recordings directory (for session-level recordings)
fallback_dir = self.base_dir / "recordings"
fallback_path = fallback_dir / response_file
if is_debug_mode():
logger.info(f" Fallback path: {fallback_path}")
logger.info(f" Fallback exists: {fallback_path.exists()}")
if fallback_path.exists():
if is_debug_mode():
logger.info(" Found in fallback location")
return _recording_from_file(fallback_path)
if is_debug_mode():
logger.info(" Recording not found in either location")
return None
def _model_list_responses(self, request_hash: str) -> list[dict[str, Any]]:
@ -588,6 +651,13 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
mode = _current_mode
storage = _current_storage
if is_debug_mode():
logger.info("[RECORDING DEBUG] Entering inference method:")
logger.info(f" Mode: {mode}")
logger.info(f" Client type: {client_type}")
logger.info(f" Endpoint: {endpoint}")
logger.info(f" Test context: {get_test_context()}")
if mode == APIRecordingMode.LIVE or storage is None:
if endpoint == "/v1/models":
return original_method(self, *args, **kwargs)
@ -643,6 +713,18 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
return response_body
elif mode == APIRecordingMode.REPLAY:
# REPLAY mode requires recording to exist
if is_debug_mode():
logger.error("[RECORDING DEBUG] Recording not found!")
logger.error(f" Mode: {mode}")
logger.error(f" Request hash: {request_hash}")
logger.error(f" Method: {method}")
logger.error(f" URL: {url}")
logger.error(f" Endpoint: {endpoint}")
logger.error(f" Model: {body.get('model', 'unknown')}")
logger.error(f" Test context: {get_test_context()}")
logger.error(
f" Stack config type: {os.environ.get('LLAMA_STACK_TEST_STACK_CONFIG_TYPE', 'library_client')}"
)
raise RuntimeError(
f"Recording not found for request hash: {request_hash}\n"
f"Model: {body.get('model', 'unknown')} | Request: {method} {url}\n"

File diff suppressed because it is too large Load diff

View file

@ -43,16 +43,16 @@
"@testing-library/dom": "^10.4.1",
"@testing-library/jest-dom": "^6.8.0",
"@testing-library/react": "^16.3.0",
"@types/jest": "^29.5.14",
"@types/jest": "^30.0.0",
"@types/node": "^24",
"@types/react": "^19",
"@types/react-dom": "^19",
"eslint": "^9",
"eslint-config-next": "15.5.2",
"eslint-config-next": "15.5.6",
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-prettier": "^5.5.4",
"jest": "^29.7.0",
"jest-environment-jsdom": "^30.1.2",
"jest": "^30.2.0",
"jest-environment-jsdom": "^30.2.0",
"prettier": "3.6.2",
"tailwindcss": "^4",
"ts-node": "^10.9.2",