Merge branch 'main' into fix-chroma

This commit is contained in:
Francisco Arceo 2025-07-24 21:48:51 -04:00 committed by GitHub
commit 062c6a419a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
76 changed files with 2468 additions and 913 deletions

View file

@ -4,15 +4,83 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import Enum, EnumMeta
from pydantic import BaseModel
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
class DynamicApiMeta(EnumMeta):
def __new__(cls, name, bases, namespace):
# Store the original enum values
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
# Create the enum class
cls = super().__new__(cls, name, bases, namespace)
# Store the original values for reference
cls._original_values = original_values
# Initialize _dynamic_values
cls._dynamic_values = {}
return cls
def __call__(cls, value):
try:
return super().__call__(value)
except ValueError as e:
# If this value was already dynamically added, return it
if value in cls._dynamic_values:
return cls._dynamic_values[value]
# If the value doesn't exist, create a new enum member
# Create a new member name from the value
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return the existing member
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
# register new APIs explicitly
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
def __iter__(cls):
# Allow iteration over both static and dynamic members
yield from super().__iter__()
if hasattr(cls, "_dynamic_values"):
yield from cls._dynamic_values.values()
def add(cls, value):
"""
Add a new API to the enum.
Used to register external APIs.
"""
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return it
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Create a new enum member
member = object.__new__(cls)
member._name_ = member_name
member._value_ = value
# Add it to the enum class
cls._member_map_[member_name] = member
cls._member_names_.append(member_name)
cls._member_type_ = str
# Store it in our dynamic values
cls._dynamic_values[value] = member
return member
@json_schema_type
class Api(Enum):
class Api(Enum, metaclass=DynamicApiMeta):
providers = "providers"
inference = "inference"
safety = "safety"
@ -54,3 +122,12 @@ class Error(BaseModel):
title: str
detail: str
instance: str | None = None
class ExternalApiSpec(BaseModel):
"""Specification for an external API implementation."""
module: str = Field(..., description="Python module containing the API implementation")
name: str = Field(..., description="Name of the API")
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
protocol: str = Field(..., description="Name of the protocol class for the API")

View file

@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ...
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None: ...
class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.

View file

@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7
REQUIRED_SCOPE = "telemetry.read"
@json_schema_type
class SpanStatus(Enum):
@ -259,7 +261,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces", method="POST")
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
async def query_traces(
self,
attribute_filters: list[QueryCondition] | None = None,
@ -277,7 +279,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
@ -286,7 +288,9 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
@webmethod(
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
)
async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
@ -296,7 +300,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
async def get_span_tree(
self,
span_id: str,
@ -312,7 +316,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/spans", method="POST")
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
async def query_spans(
self,
attribute_filters: list[QueryCondition],
@ -345,7 +349,7 @@ class Telemetry(Protocol):
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
async def query_metrics(
self,
metric_name: str,

View file

@ -36,6 +36,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
@ -404,6 +405,29 @@ def _run_stack_build_command_from_build_config(
to_write = json.loads(build_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
# We first install the external APIs so that the build process can use them and discover the
# providers dependencies
if build_config.external_apis_dir:
cprint("Installing external APIs", color="yellow", file=sys.stderr)
external_apis = load_external_apis(build_config)
if external_apis:
# install the external APIs
packages = []
for _, api_spec in external_apis.items():
if api_spec.pip_packages:
packages.extend(api_spec.pip_packages)
cprint(
f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}",
color="yellow",
file=sys.stderr,
)
return_code = run_command(["uv", "pip", "install", *packages])
if return_code != 0:
packages_str = ", ".join(packages)
raise RuntimeError(
f"Failed to install external APIs packages: {packages_str} (return code: {return_code})"
)
return_code = build_image(
build_config,
build_file_path,

View file

@ -14,6 +14,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
@ -105,6 +106,11 @@ def build_image(
normal_deps, special_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")

View file

@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str]
class RegistryEntrySource(StrEnum):
via_register_api = "via_register_api"
listed_from_provider = "listed_from_provider"
class User(BaseModel):
principal: str
# further attributes that may be used for access control decisions
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
resource. This can be used to constrain access to the resource."""
owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.via_register_api
# Use the extended Resource for all routable objects
@ -381,6 +387,11 @@ a default SQLite store will be used.""",
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
@ -412,6 +423,10 @@ class BuildConfig(BaseModel):
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir")
@classmethod

View file

@ -12,6 +12,7 @@ from typing import Any
import yaml
from pydantic import BaseModel
from llama_stack.distribution.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
@ -133,16 +134,34 @@ def get_provider_registry(
ValueError: If any provider spec is invalid
"""
ret: dict[Api, dict[str, ProviderSpec]] = {}
registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis():
name = api.name.lower()
logger.debug(f"Importing module {name}")
try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {a.provider_type: a for a in module.available_providers()}
registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}")
# Refresh providable APIs with external APIs if any
external_apis = load_external_apis(config)
for api, api_spec in external_apis.items():
name = api_spec.name.lower()
logger.info(f"Importing external API {name} module {api_spec.module}")
try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except (ImportError, AttributeError) as e:
# Populate the registry with an empty dict to avoid breaking the provider registry
# This assume that the in-tree provider(s) are not available for this API which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
# Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
@ -175,11 +194,9 @@ def get_provider_registry(
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in ret[api]:
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
@ -187,4 +204,4 @@ def get_provider_registry(
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return ret
return registry

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig or BuildConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config or not config.external_apis_dir:
return {}
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
for yaml_path in external_apis_dir.glob("*.yaml"):
try:
with open(yaml_path) as f:
spec_data = yaml.safe_load(f)
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis

View file

@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo,
)
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config
ret = []
all_endpoints = get_all_api_routes()
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]:
@ -53,7 +55,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e in endpoints
for e, _ in endpoints
if e.methods is not None
]
)
else:
@ -66,7 +69,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
for e, _ in endpoints
if e.methods is not None
]
)

View file

@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
return self.loop.run_until_complete(self.async_client.initialize())
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self):
"""
@ -353,13 +359,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
await start_trace(route, {"__location__": "library_client"})
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
@ -409,12 +417,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"})
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
async def gen():
try:
@ -445,8 +454,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]]
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
response = AsyncAPIResponse(
raw=mock_response,
client=self,
@ -468,7 +478,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
exclude_params = exclude_params or set()
func, _, _ = find_matching_route(method, path, self.route_impls)
func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature

View file

@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
if not provider_data:
return None
return provider_data.get("__authenticated_user")
def user_from_scope(scope: dict) -> User | None:
"""Create a User object from ASGI scope data (set by authentication middleware)"""
user_attributes = scope.get("user_attributes", {})
principal = scope.get("principal", "")
# auth not enabled
if not principal and not user_attributes:
return None
return User(principal=principal, attributes=user_attributes)

View file

@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
pass
def api_protocol_map() -> dict[Api, Any]:
return {
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
"""Get a mapping of API types to their protocol classes.
Args:
external_apis: Optional dictionary of external API specifications
Returns:
Dictionary mapping API types to their protocol classes
"""
protocols = {
Api.providers: ProvidersAPI,
Api.agents: Agents,
Api.inference: Inference,
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
Api.files: Files,
}
if external_apis:
for api, api_spec in external_apis.items():
try:
module = importlib.import_module(api_spec.module)
api_class = getattr(module, api_spec.protocol)
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
protocols[api] = api_class
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
return protocols
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
external_apis = load_external_apis(config)
return {
**api_protocol_map(),
**api_protocol_map(external_apis),
Api.inference: InferenceProvider,
}
@ -250,7 +273,7 @@ async def instantiate_providers(
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
policy: list[AccessRule],
) -> dict:
) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
@ -360,7 +383,7 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
protocols = api_protocol_map_for_compliance_check(run_config)
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol

View file

@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values():
await p.shutdown()
async def refresh(self) -> None:
pass
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.log import get_logger
@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
refresh = await provider.should_refresh_models()
if not (refresh or provider_id in self.listed_providers):
continue
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
if models is None:
continue
await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
source=RegistryEntrySource.via_register_api,
)
registered_model = await self.register_object(model)
return registered_model
@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
async def update_registered_llm_models(
async def update_registered_models(
self,
provider_id: str,
models: list[Model],
@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
# we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
if model.provider_id != provider_id:
continue
if model.source == RegistryEntrySource.via_register_api:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
continue
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]
# avoid overwriting a non-provider-registered model entry
continue
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
source=RegistryEntrySource.listed_from_provider,
)
)

View file

@ -7,9 +7,12 @@
import json
import httpx
from aiohttp import hdrs
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.datatypes import AuthenticationConfig, User
from llama_stack.distribution.request_headers import user_from_scope
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -78,12 +81,14 @@ class AuthenticationMiddleware:
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthenticationConfig):
def __init__(self, app, auth_config: AuthenticationConfig, impls):
self.app = app
self.impls = impls
self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
# First, handle authentication
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
@ -121,15 +126,50 @@ class AuthenticationMiddleware:
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
)
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(
send,
f"Access denied: user does not have required scope: {webmethod.required_scope}",
status=403,
)
return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message):
async def _send_auth_error(self, send, message, status=401):
await send(
{
"type": "http.response.start",
"status": 401,
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
error_msg = json.dumps({"error": {"message": message}}).encode()
error_key = "message" if status == 401 else "detail"
error_msg = json.dumps({"error": {error_key: message}}).encode()
await send({"type": "http.response.body", "body": error_msg})
def _has_required_scope(required_scope: str, user: User | None) -> bool:
# if no user, assume auth is not enabled
if not user:
return True
if not user.attributes:
return False
user_scopes = user.attributes.get("scopes", [])
return required_scope in user_scopes

View file

@ -12,17 +12,18 @@ from typing import Any
from aiohttp import hdrs
from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api
from llama_stack.schema_utils import WebMethod
EndpointFunc = Callable[..., Any]
PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str]
RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str]
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map():
@ -31,10 +32,12 @@ def toolgroup_protocol_map():
}
def get_all_api_routes() -> dict[Api, list[Route]]:
def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {}
protocols = api_protocol_map()
protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items():
routes = []
@ -65,7 +68,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
else:
http_method = hdrs.METH_POST
routes.append(
Route(path=path, methods=[http_method], name=name, endpoint=None)
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since don't use a Router object
apis[api] = routes
@ -73,8 +76,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
return apis
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
routes = get_all_api_routes()
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
api_to_routes = get_all_api_routes(external_apis)
route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str:
@ -88,10 +91,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
return f"^{pattern}$"
for api, api_routes in routes.items():
for api, api_routes in api_to_routes.items():
if api not in impls:
continue
for route in api_routes:
for route, webmethod in api_routes:
impl = impls[api]
func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD
@ -104,6 +107,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
route_impls[method][_convert_path_to_regex(route.path)] = (
func,
route.path,
webmethod,
)
return route_impls
@ -118,7 +122,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
route_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
Raises:
ValueError: If no matching endpoint is found
@ -127,11 +131,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
for regex, (func, route_path, webmethod) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
return func, path_params, route_path, webmethod
raise ValueError(f"No endpoint found for {path}")

View file

@ -40,7 +40,12 @@ from llama_stack.distribution.datatypes import (
StackRunConfig,
)
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import (
find_matching_route,
@ -222,9 +227,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func)
async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {})
principal = request.scope.get("principal", "")
user = User(principal=principal, attributes=user_attributes)
user = user_from_scope(request.scope)
await log_request_pre_validation(request)
@ -282,9 +285,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
class TracingMiddleware:
def __init__(self, app, impls):
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app
self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
@ -301,10 +305,12 @@ class TracingMiddleware:
return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try:
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
_, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
@ -321,6 +327,7 @@ class TracingMiddleware:
if tracestate:
trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message):
@ -432,10 +439,21 @@ def main(args: argparse.Namespace | None = None):
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else:
if config.server.quota:
quota = config.server.quota
@ -466,24 +484,14 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes()
# Load external APIs if configured
external_apis = load_external_apis(config)
all_routes = get_all_api_routes(external_apis)
if config.apis:
apis_to_serve = set(config.apis)
@ -502,9 +510,12 @@ def main(args: argparse.Namespace | None = None):
api = Api(api_str)
routes = all_routes[api]
impl = impls[api]
try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
for route in routes:
for route, _ in routes:
if not hasattr(impl, route.name):
# ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!")
@ -533,7 +544,7 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
import uvicorn

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import importlib.resources
import os
import re
@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
@ -90,6 +92,9 @@ RESOURCES = [
]
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc)
@ -324,9 +329,33 @@ async def construct_stack(
add_internal_implementations(impls, run_config)
await register_resources(run_config, impls)
task = asyncio.create_task(refresh_registry(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
task.add_done_callback(cb)
return impls
async def refresh_registry(impls: dict[Api, Any]):
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
while True:
for routing_table in routing_tables:
await routing_table.refresh()
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"

View file

@ -6,6 +6,7 @@
import logging
import os
import re
import sys
from logging.config import dictConfig
@ -30,6 +31,7 @@ CATEGORIES = [
"eval",
"tools",
"client",
"telemetry",
]
# Initialize category levels with default level
@ -113,6 +115,11 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
return category_levels
def strip_rich_markup(text):
"""Remove Rich markup tags like [dim], [bold magenta], etc."""
return re.sub(r"\[/?[a-zA-Z0-9 _#=,]+\]", "", text)
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=150)
@ -131,6 +138,19 @@ class CustomRichHandler(RichHandler):
self.markup = original_markup
class CustomFileHandler(logging.FileHandler):
def __init__(self, filename, mode="a", encoding=None, delay=False):
super().__init__(filename, mode, encoding, delay)
# Default formatter to match console output
self.default_formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s")
self.setFormatter(self.default_formatter)
def emit(self, record):
if hasattr(record, "msg"):
record.msg = strip_rich_markup(str(record.msg))
super().emit(record)
def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
"""
Configure logging based on the provided category log levels and an optional log file.
@ -167,8 +187,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
# Add a file handler if log_file is set
if log_file:
handlers["file"] = {
"class": "logging.FileHandler",
"formatter": "rich",
"()": CustomFileHandler,
"filename": log_file,
"mode": "a",
"encoding": "utf-8",

View file

@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
async def unregister_model(self, model_id: str) -> None: ...
# the Stack router will query each provider for their list of models
# if a `refresh_interval_seconds` is provided, this method will be called
# periodically to refresh the list of models
#
# NOTE: each model returned will be registered with the model registry. this means
# a callback to the `register_model()` method will be made. this is duplicative and
# may be removed in the future.
async def list_models(self) -> list[Model] | None: ...
async def should_refresh_models(self) -> bool: ...
class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ...

View file

@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group:
self.generator.stop()
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return None
async def unregister_model(self, model_id: str) -> None:
pass

View file

@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
InferenceProvider,
ModelsProtocolPrivate,
):
__provider_id__: str
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return [
Model(
identifier="all-MiniLM-L6-v2",
provider_resource_id="all-MiniLM-L6-v2",
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 384,
},
model_type=ModelType.embedding,
),
]
async def register_model(self, model: Model) -> Model:
return model

View file

@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode
# Colors for console output
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
from llama_stack.log import get_logger
logger = get_logger(name="console_span_processor", category="telemetry")
class ConsoleSpanProcessor(SpanProcessor):
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
return
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
if span.status.status_code == StatusCode.ERROR:
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
span_context += " [bold red][ERROR][/bold red]"
elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
span_context += f" [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
print(span_context)
span_context += f" ({duration_ms:.2f}ms)"
logger.info(span_context)
if self.print_attributes and span.attributes:
for key, value in span.attributes.items():
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
str_value = str(value)
if len(str_value) > 1000:
str_value = str_value[:997] + "..."
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
logger.info(f" [dim]{key}[/dim]: {str_value}")
for event in span.events:
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, dict | list):
if isinstance(message, dict) or isinstance(message, list):
message = json.dumps(message, indent=2)
severity_colors = {
"error": f"{COLORS['bold']}{COLORS['red']}",
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
"info": COLORS["white"],
"debug": COLORS["dim"],
}
msg_color = severity_colors.get(severity, COLORS["white"])
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
severity_color = {
"error": "red",
"warn": "yellow",
"info": "white",
"debug": "dim",
}.get(severity, "white")
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
if event.attributes:
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
logger.info(f"/r[dim]{key}[/dim]: {value}")
def shutdown(self) -> None:
"""Shutdown the processor."""

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class FireworksImplConfig(BaseModel):
class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server",

View file

@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
async def initialize(self) -> None:

View file

@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
refresh_models: bool = Field(
default=False,
description="Whether to refresh models periodically",
)
@classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

View file

@ -98,14 +98,16 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.config = config
self._client = None
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
self._openai_client = None
@property
def client(self) -> AsyncClient:
if self._client is None:
self._client = AsyncClient(host=self.config.url)
return self._client
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url)
return self._clients[loop]
@property
def openai_client(self) -> AsyncOpenAI:
@ -121,59 +123,61 @@ class OllamaInferenceAdapter(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
)
if self.config.refresh_models:
logger.debug("ollama starting background model refresh task")
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
if task.cancelled():
import traceback
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
# Wait for model store to be available (with timeout)
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
while True:
try:
response = await self.client.list()
except Exception as e:
logger.warning(f"Failed to list models: {str(e)}")
await asyncio.sleep(self.config.refresh_models_interval)
response = await self.client.list()
# always add the two embedding models which can be pulled on demand
models = [
Model(
identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models:
# kill embedding models since we don't know dimensions for them
if m.details.family in ["bert"]:
continue
models = []
for m in response.models:
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
if model_type == ModelType.embedding:
continue
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
models.append(
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=ModelType.llm,
)
await self.model_store.update_registered_llm_models(provider_id, models)
logger.debug(f"ollama refreshed model list ({len(models)} models)")
await asyncio.sleep(self.config.refresh_models_interval)
)
return models
async def health(self) -> HealthResponse:
"""
@ -190,12 +194,7 @@ class OllamaInferenceAdapter(
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def shutdown(self) -> None:
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()
self._client = None
self._openai_client = None
self._clients.clear()
async def unregister_model(self, model_id: str) -> None:
pass

View file

@ -12,11 +12,6 @@ from llama_stack.providers.utils.inference.model_registry import (
)
LLM_MODEL_IDS = [
# the models w/ "openai/" prefix are the litellm specific model names.
# they should be deprecated in favor of the canonical openai model names.
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openai/chatgpt-4o-latest",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-3.5-turbo-instruct",
@ -43,8 +38,6 @@ class EmbeddingModelInfo:
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
}

View file

@ -6,13 +6,14 @@
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class TogetherImplConfig(BaseModel):
class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
default="https://api.together.xyz/v1",
description="The URL for the Together AI server",

View file

@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
async def initialize(self) -> None:

View file

@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=False,
description="Whether to refresh models periodically",
)
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify")
@classmethod

View file

@ -3,7 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider
__provider_id__: str
model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
if not self.config.url:
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
pass
if self.config.refresh_models:
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def should_refresh_models(self) -> bool:
return self.config.refresh_models
async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client()
assert self.client is not None # mypy
while True:
try:
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
models = []
async for m in self.client.models.list():
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
models.append(
Model(
identifier=m.id,
provider_resource_id=m.id,
provider_id=self.__provider_id__,
metadata={},
model_type=model_type,
)
)
return models
async def shutdown(self) -> None:
if self._refresh_task:
self._refresh_task.cancel()
self._refresh_task = None
pass
async def unregister_model(self, model_id: str) -> None:
pass

View file

@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
logger = get_logger(name=__name__, category="core")
class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field(
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)
# TODO: this class is more confusing than useful right now. We need to make it
# more closer to the Model class.
class ProviderModelEntry(BaseModel):
@ -65,7 +72,10 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_entries: list[ProviderModelEntry]):
__provider_id__: str
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
self.allowed_models = allowed_models
self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {}
for entry in model_entries:
@ -79,6 +89,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
async def list_models(self) -> list[Model] | None:
models = []
for entry in self.model_entries:
ids = [entry.provider_model_id] + entry.aliases
for id in ids:
if self.allowed_models and id not in self.allowed_models:
continue
models.append(
Model(
model_id=id,
provider_resource_id=entry.provider_model_id,
model_type=ModelType.llm,
metadata=entry.metadata,
provider_id=self.__provider_id__,
)
)
return models
async def should_refresh_models(self) -> bool:
return False
def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None)

View file

@ -4,13 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast
import httpx
from mcp import ClientSession
from mcp import ClientSession, McpError
from mcp import types as mcp_types
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import (
@ -21,31 +24,61 @@ from llama_stack.apis.tools import (
)
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
class MCPProtol(Enum):
UNKNOWN = 0
STREAMABLE_HTTP = 1
SSE = 2
@asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
try:
async with sse_client(endpoint, headers=headers) as streams:
async with ClientSession(*streams) as session:
await session.initialize()
yield session
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
raise
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
# we use a ttl'd dict to cache the happy path protocol for each endpoint
# but, we always fall back to trying the other protocol if we cannot initialize the session
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
if mcp_protocol == MCPProtol.SSE:
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
for i, strategy in enumerate(connection_strategies):
try:
client = streamablehttp_client
if strategy == MCPProtol.SSE:
client = sse_client
async with client(endpoint, headers=headers) as client_streams:
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
await session.initialize()
protocol_cache[endpoint] = strategy
yield session
return
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1:
raise
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
else:
raise
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = []
async with sse_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools()
for tool in tools_result.tools:
parameters = []
@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session:
async with client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = []

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from threading import RLock
from typing import Any
class TTLDict(dict):
"""
A dictionary with a ttl for each item
"""
def __init__(self, ttl_seconds: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ttl_seconds = ttl_seconds
self._expires: dict[Any, Any] = {} # expires holds when an item will expire
self._lock = RLock()
if args or kwargs:
for k, v in self.items():
self.__setitem__(k, v)
def __delitem__(self, key):
with self._lock:
del self._expires[key]
super().__delitem__(key)
def __setitem__(self, key, value):
with self._lock:
self._expires[key] = time.monotonic() + self.ttl_seconds
super().__setitem__(key, value)
def _is_expired(self, key):
if key not in self._expires:
return False
return time.monotonic() > self._expires[key]
def __getitem__(self, key):
with self._lock:
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
raise KeyError(f"{key} has expired and was removed")
return super().__getitem__(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
try:
_ = self[key]
return True
except KeyError:
return False
def __repr__(self):
with self._lock:
for key in self.keys():
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"

View file

@ -22,6 +22,7 @@ class WebMethod:
# A descriptive name of the corresponding span created by tracing
descriptive_name: str | None = None
experimental: bool | None = False
required_scope: str | None = None
T = TypeVar("T", bound=Callable[..., Any])
@ -36,6 +37,7 @@ def webmethod(
raw_bytes_request_body: bool | None = False,
descriptive_name: str | None = None,
experimental: bool | None = False,
required_scope: str | None = None,
) -> Callable[[T], T]:
"""
Decorator that supplies additional metadata to an endpoint operation function.
@ -45,6 +47,7 @@ def webmethod(
:param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
:param experimental: True if the operation is experimental and subject to change.
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
"""
def wrap(func: T) -> T:
@ -57,6 +60,7 @@ def webmethod(
raw_bytes_request_body=raw_bytes_request_body,
descriptive_name=descriptive_name,
experimental=experimental,
required_scope=required_scope,
)
return func

View file

@ -785,21 +785,6 @@ models:
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-3B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o-mini
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/gpt-4o-mini
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/chatgpt-4o-latest
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/chatgpt-4o-latest
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-3.5-turbo-0125
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
@ -870,20 +855,6 @@ models:
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: o4-mini
model_type: llm
- metadata:
embedding_dimension: 1536
context_length: 8192
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-small
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/text-embedding-3-small
model_type: embedding
- metadata:
embedding_dimension: 3072
context_length: 8192
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-large
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/text-embedding-3-large
model_type: embedding
- metadata:
embedding_dimension: 1536
context_length: 8192

View file

@ -1,3 +1,6 @@
---
orphan: true
---
# NVIDIA Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.

View file

@ -785,21 +785,6 @@ models:
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-3B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/gpt-4o
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o-mini
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/gpt-4o-mini
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/chatgpt-4o-latest
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/chatgpt-4o-latest
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-3.5-turbo-0125
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
@ -870,20 +855,6 @@ models:
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: o4-mini
model_type: llm
- metadata:
embedding_dimension: 1536
context_length: 8192
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-small
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/text-embedding-3-small
model_type: embedding
- metadata:
embedding_dimension: 3072
context_length: 8192
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-large
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_model_id: openai/text-embedding-3-large
model_type: embedding
- metadata:
embedding_dimension: 1536
context_length: 8192

View file

@ -15,7 +15,7 @@
"@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"llama-stack-client": "^0.2.14",
"llama-stack-client": "^0.2.15",
"lucide-react": "^0.510.0",
"next": "15.3.3",
"next-auth": "^4.24.11",
@ -6468,14 +6468,15 @@
}
},
"node_modules/form-data": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz",
"integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==",
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz",
"integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==",
"license": "MIT",
"dependencies": {
"asynckit": "^0.4.0",
"combined-stream": "^1.0.8",
"es-set-tostringtag": "^2.1.0",
"hasown": "^2.0.2",
"mime-types": "^2.1.12"
},
"engines": {
@ -9099,9 +9100,9 @@
"license": "MIT"
},
"node_modules/llama-stack-client": {
"version": "0.2.14",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz",
"integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==",
"version": "0.2.15",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.15.tgz",
"integrity": "sha512-onfYzgPWAxve4uP7BuiK/ZdEC7w6X1PIXXXpQY57qZC7C4xUAM5kwfT3JWIe/jE22Lwc2vTN1ScfYlAYcoYAsg==",
"license": "Apache-2.0",
"dependencies": {
"@types/node": "^18.11.18",