mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
Merge branch 'main' into remove-deprecated-embeddings
This commit is contained in:
commit
5c44dcdf0e
770 changed files with 176834 additions and 27431 deletions
|
@ -147,7 +147,7 @@ WORKDIR /app
|
|||
|
||||
RUN dnf -y update && dnf install -y iputils git net-tools wget \
|
||||
vim-minimal python3.12 python3.12-pip python3.12-wheel \
|
||||
python3.12-setuptools python3.12-devel gcc make && \
|
||||
python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
|
||||
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
|
|||
procps psmisc lsof \
|
||||
traceroute \
|
||||
bubblewrap \
|
||||
gcc \
|
||||
gcc g++ \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
|
|
@ -15,7 +15,6 @@ import httpx
|
|||
from pydantic import BaseModel, parse_obj_as
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
|
||||
_CLIENT_CLASSES = {}
|
||||
|
@ -114,7 +113,24 @@ def create_api_client_class(protocol) -> type:
|
|||
break
|
||||
kwargs[param.name] = args[i]
|
||||
|
||||
url = f"{self.base_url}/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
|
||||
if not webmethods:
|
||||
raise RuntimeError(f"Method {method} has no webmethod decorators")
|
||||
|
||||
# Choose the preferred webmethod (non-deprecated if available)
|
||||
preferred_webmethod = None
|
||||
for wm in webmethods:
|
||||
if not getattr(wm, "deprecated", False):
|
||||
preferred_webmethod = wm
|
||||
break
|
||||
|
||||
# If no non-deprecated found, use the first one
|
||||
if preferred_webmethod is None:
|
||||
preferred_webmethod = webmethods[0]
|
||||
|
||||
url = f"{self.base_url}/{preferred_webmethod.level}/{preferred_webmethod.route.lstrip('/')}"
|
||||
|
||||
def convert(value):
|
||||
if isinstance(value, list):
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal, Self
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
|
@ -120,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
|
|||
default=None,
|
||||
)
|
||||
|
||||
@property
|
||||
def pip_packages(self) -> list[str]:
|
||||
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
|
||||
|
||||
|
||||
# Example: /models, /shields
|
||||
class RoutingTableProviderSpec(ProviderSpec):
|
||||
|
@ -212,6 +209,7 @@ class AuthProviderType(StrEnum):
|
|||
OAUTH2_TOKEN = "oauth2_token"
|
||||
GITHUB_TOKEN = "github_token"
|
||||
CUSTOM = "custom"
|
||||
KUBERNETES = "kubernetes"
|
||||
|
||||
|
||||
class OAuth2TokenAuthConfig(BaseModel):
|
||||
|
@ -282,8 +280,45 @@ class GitHubTokenAuthConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class KubernetesAuthProviderConfig(BaseModel):
|
||||
"""Configuration for Kubernetes authentication provider."""
|
||||
|
||||
type: Literal[AuthProviderType.KUBERNETES] = AuthProviderType.KUBERNETES
|
||||
api_server_url: str = Field(
|
||||
default="https://kubernetes.default.svc",
|
||||
description="Kubernetes API server URL (e.g., https://api.cluster.domain:6443)",
|
||||
)
|
||||
verify_tls: bool = Field(default=True, description="Whether to verify TLS certificates")
|
||||
tls_cafile: Path | None = Field(default=None, description="Path to CA certificate file for TLS verification")
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"username": "roles",
|
||||
"groups": "roles",
|
||||
},
|
||||
description="Mapping of Kubernetes user claims to access attributes",
|
||||
)
|
||||
|
||||
@field_validator("api_server_url")
|
||||
@classmethod
|
||||
def validate_api_server_url(cls, v):
|
||||
parsed = urlparse(v)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise ValueError(f"api_server_url must be a valid URL with scheme and host: {v}")
|
||||
if parsed.scheme not in ["http", "https"]:
|
||||
raise ValueError(f"api_server_url scheme must be http or https: {v}")
|
||||
return v
|
||||
|
||||
@field_validator("claims_mapping")
|
||||
@classmethod
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
return v
|
||||
|
||||
|
||||
AuthProviderConfig = Annotated[
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig,
|
||||
OAuth2TokenAuthConfig | GitHubTokenAuthConfig | CustomAuthConfig | KubernetesAuthProviderConfig,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
@ -392,6 +427,12 @@ class ServerConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class InferenceStoreConfig(BaseModel):
|
||||
sql_store_config: SqlStoreConfig
|
||||
max_write_queue_size: int = Field(default=10000, description="Max queued writes for inference store")
|
||||
num_writers: int = Field(default=4, description="Number of concurrent background writers")
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
|
@ -425,11 +466,12 @@ Configuration for the persistence store used by the distribution registry. If no
|
|||
a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
inference_store: SqlStoreConfig | None = Field(
|
||||
inference_store: InferenceStoreConfig | SqlStoreConfig | None = Field(
|
||||
default=None,
|
||||
description="""
|
||||
Configuration for the persistence store used by the inference API. If not specified,
|
||||
a default SQLite store will be used.""",
|
||||
Configuration for the persistence store used by the inference API. Can be either a
|
||||
InferenceStoreConfig (with queue tuning parameters) or a SqlStoreConfig (deprecated).
|
||||
If not specified, a default SQLite store will be used.""",
|
||||
)
|
||||
|
||||
# registry of "resources" in the distribution
|
||||
|
|
|
@ -16,16 +16,18 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
|
|||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
INTERNAL_APIS = {Api.inspect, Api.providers, Api.prompts}
|
||||
|
||||
|
||||
def stack_apis() -> list[Api]:
|
||||
return list(Api)
|
||||
|
||||
|
@ -70,31 +72,16 @@ def builtin_automatically_routed_apis() -> list[AutoRoutedApiInfo]:
|
|||
|
||||
def providable_apis() -> list[Api]:
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
return [api for api in Api if api not in routing_table_apis and api not in INTERNAL_APIS]
|
||||
|
||||
|
||||
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
|
||||
adapter = AdapterSpec(**spec_data["adapter"])
|
||||
spec = remote_provider_spec(
|
||||
api=api,
|
||||
adapter=adapter,
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
)
|
||||
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
|
||||
return spec
|
||||
|
||||
|
||||
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
||||
spec = InlineProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"inline::{provider_name}",
|
||||
pip_packages=spec_data.get("pip_packages", []),
|
||||
module=spec_data["module"],
|
||||
config_class=spec_data["config_class"],
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
|
||||
provider_data_validator=spec_data.get("provider_data_validator"),
|
||||
container_image=spec_data.get("container_image"),
|
||||
)
|
||||
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
|
||||
return spec
|
||||
|
||||
|
||||
|
|
|
@ -10,7 +10,6 @@ import json
|
|||
import logging # allow-direct-logging
|
||||
import os
|
||||
import sys
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
@ -41,7 +40,7 @@ from llama_stack.core.request_headers import (
|
|||
from llama_stack.core.resolver import ProviderRegistry
|
||||
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
|
||||
from llama_stack.core.stack import (
|
||||
construct_stack,
|
||||
Stack,
|
||||
get_stack_run_config_from_distro,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
self.provider_data = provider_data
|
||||
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
@ -254,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
try:
|
||||
self.route_impls = None
|
||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||
|
||||
stack = Stack(self.config, self.custom_provider_registry)
|
||||
await stack.initialize()
|
||||
self.impls = stack.impls
|
||||
except ModuleNotFoundError as _e:
|
||||
cprint(_e.msg, color="red", file=sys.stderr)
|
||||
cprint(
|
||||
|
@ -291,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
)
|
||||
raise _e
|
||||
|
||||
assert self.impls is not None
|
||||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
|
|
5
llama_stack/core/prompts/__init__.py
Normal file
5
llama_stack/core/prompts/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
233
llama_stack/core/prompts/prompts.py
Normal file
233
llama_stack/core/prompts/prompts.py
Normal file
|
@ -0,0 +1,233 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.prompts import ListPromptsResponse, Prompt, Prompts
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class PromptServiceConfig(BaseModel):
|
||||
"""Configuration for the built-in prompt service.
|
||||
|
||||
:param run_config: Stack run configuration containing distribution info
|
||||
"""
|
||||
|
||||
run_config: StackRunConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
"""Get the prompt service implementation."""
|
||||
impl = PromptServiceImpl(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
||||
|
||||
class PromptServiceImpl(Prompts):
|
||||
"""Built-in prompt service implementation using KVStore."""
|
||||
|
||||
def __init__(self, config: PromptServiceConfig, deps: dict[Any, Any]):
|
||||
self.config = config
|
||||
self.deps = deps
|
||||
self.kvstore: KVStore
|
||||
|
||||
async def initialize(self) -> None:
|
||||
kvstore_config = SqliteKVStoreConfig(
|
||||
db_path=(DISTRIBS_BASE_DIR / self.config.run_config.image_name / "prompts.db").as_posix()
|
||||
)
|
||||
self.kvstore = await kvstore_impl(kvstore_config)
|
||||
|
||||
def _get_default_key(self, prompt_id: str) -> str:
|
||||
"""Get the KVStore key that stores the default version number."""
|
||||
return f"prompts:v1:{prompt_id}:default"
|
||||
|
||||
async def _get_prompt_key(self, prompt_id: str, version: int | None = None) -> str:
|
||||
"""Get the KVStore key for prompt data, returning default version if applicable."""
|
||||
if version:
|
||||
return self._get_version_key(prompt_id, str(version))
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
resolved_version = await self.kvstore.get(default_key)
|
||||
if resolved_version is None:
|
||||
raise ValueError(f"Prompt {prompt_id}:default not found")
|
||||
return self._get_version_key(prompt_id, resolved_version)
|
||||
|
||||
def _get_version_key(self, prompt_id: str, version: str) -> str:
|
||||
"""Get the KVStore key for a specific prompt version."""
|
||||
return f"prompts:v1:{prompt_id}:{version}"
|
||||
|
||||
def _get_list_key_prefix(self) -> str:
|
||||
"""Get the key prefix for listing prompts."""
|
||||
return "prompts:v1:"
|
||||
|
||||
def _serialize_prompt(self, prompt: Prompt) -> str:
|
||||
"""Serialize a prompt to JSON string for storage."""
|
||||
return json.dumps(
|
||||
{
|
||||
"prompt_id": prompt.prompt_id,
|
||||
"prompt": prompt.prompt,
|
||||
"version": prompt.version,
|
||||
"variables": prompt.variables or [],
|
||||
"is_default": prompt.is_default,
|
||||
}
|
||||
)
|
||||
|
||||
def _deserialize_prompt(self, data: str) -> Prompt:
|
||||
"""Deserialize a prompt from JSON string."""
|
||||
obj = json.loads(data)
|
||||
return Prompt(
|
||||
prompt_id=obj["prompt_id"],
|
||||
prompt=obj["prompt"],
|
||||
version=obj["version"],
|
||||
variables=obj.get("variables", []),
|
||||
is_default=obj.get("is_default", False),
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> ListPromptsResponse:
|
||||
"""List all prompts (default versions only)."""
|
||||
prefix = self._get_list_key_prefix()
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
prompts = []
|
||||
for key in keys:
|
||||
if key.endswith(":default"):
|
||||
try:
|
||||
default_version = await self.kvstore.get(key)
|
||||
if default_version:
|
||||
prompt_id = key.replace(prefix, "").replace(":default", "")
|
||||
version_key = self._get_version_key(prompt_id, default_version)
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data:
|
||||
prompt = self._deserialize_prompt(data)
|
||||
prompts.append(prompt)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
prompts.sort(key=lambda p: p.prompt_id or "", reverse=True)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def get_prompt(self, prompt_id: str, version: int | None = None) -> Prompt:
|
||||
"""Get a prompt by its identifier and optional version."""
|
||||
key = await self._get_prompt_key(prompt_id, version)
|
||||
data = await self.kvstore.get(key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id}:{version if version else 'default'} not found")
|
||||
return self._deserialize_prompt(data)
|
||||
|
||||
async def create_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
variables: list[str] | None = None,
|
||||
) -> Prompt:
|
||||
"""Create a new prompt."""
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_obj = Prompt(
|
||||
prompt_id=Prompt.generate_prompt_id(),
|
||||
prompt=prompt,
|
||||
version=1,
|
||||
variables=variables,
|
||||
)
|
||||
|
||||
version_key = self._get_version_key(prompt_obj.prompt_id, str(prompt_obj.version))
|
||||
data = self._serialize_prompt(prompt_obj)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
default_key = self._get_default_key(prompt_obj.prompt_id)
|
||||
await self.kvstore.set(default_key, str(prompt_obj.version))
|
||||
|
||||
return prompt_obj
|
||||
|
||||
async def update_prompt(
|
||||
self,
|
||||
prompt_id: str,
|
||||
prompt: str,
|
||||
version: int,
|
||||
variables: list[str] | None = None,
|
||||
set_as_default: bool = True,
|
||||
) -> Prompt:
|
||||
"""Update an existing prompt (increments version)."""
|
||||
if version < 1:
|
||||
raise ValueError("Version must be >= 1")
|
||||
if variables is None:
|
||||
variables = []
|
||||
|
||||
prompt_versions = await self.list_prompt_versions(prompt_id)
|
||||
latest_prompt = max(prompt_versions.data, key=lambda x: int(x.version))
|
||||
|
||||
if version and latest_prompt.version != version:
|
||||
raise ValueError(
|
||||
f"'{version}' is not the latest prompt version for prompt_id='{prompt_id}'. Use the latest version '{latest_prompt.version}' in request."
|
||||
)
|
||||
|
||||
current_version = latest_prompt.version if version is None else version
|
||||
new_version = current_version + 1
|
||||
|
||||
updated_prompt = Prompt(prompt_id=prompt_id, prompt=prompt, version=new_version, variables=variables)
|
||||
|
||||
version_key = self._get_version_key(prompt_id, str(new_version))
|
||||
data = self._serialize_prompt(updated_prompt)
|
||||
await self.kvstore.set(version_key, data)
|
||||
|
||||
if set_as_default:
|
||||
await self.set_default_version(prompt_id, new_version)
|
||||
|
||||
return updated_prompt
|
||||
|
||||
async def delete_prompt(self, prompt_id: str) -> None:
|
||||
"""Delete a prompt and all its versions."""
|
||||
await self.get_prompt(prompt_id)
|
||||
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
for key in keys:
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def list_prompt_versions(self, prompt_id: str) -> ListPromptsResponse:
|
||||
"""List all versions of a specific prompt."""
|
||||
prefix = f"prompts:v1:{prompt_id}:"
|
||||
keys = await self.kvstore.keys_in_range(prefix, prefix + "\xff")
|
||||
|
||||
default_version = None
|
||||
prompts = []
|
||||
|
||||
for key in keys:
|
||||
data = await self.kvstore.get(key)
|
||||
if key.endswith(":default"):
|
||||
default_version = data
|
||||
else:
|
||||
if data:
|
||||
prompt_obj = self._deserialize_prompt(data)
|
||||
prompts.append(prompt_obj)
|
||||
|
||||
if not prompts:
|
||||
raise ValueError(f"Prompt {prompt_id} not found")
|
||||
|
||||
for prompt in prompts:
|
||||
prompt.is_default = str(prompt.version) == default_version
|
||||
|
||||
prompts.sort(key=lambda x: x.version)
|
||||
return ListPromptsResponse(data=prompts)
|
||||
|
||||
async def set_default_version(self, prompt_id: str, version: int) -> Prompt:
|
||||
"""Set which version of a prompt should be the default, If not set. the default is the latest."""
|
||||
version_key = self._get_version_key(prompt_id, str(version))
|
||||
data = await self.kvstore.get(version_key)
|
||||
if data is None:
|
||||
raise ValueError(f"Prompt {prompt_id} version {version} not found")
|
||||
|
||||
default_key = self._get_default_key(prompt_id)
|
||||
await self.kvstore.set(default_key, str(version))
|
||||
|
||||
return self._deserialize_prompt(data)
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.inference import Inference, InferenceProvider
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.prompts import Prompts
|
||||
from llama_stack.apis.providers import Providers as ProvidersAPI
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
@ -93,6 +94,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.tool_groups: ToolGroups,
|
||||
Api.tool_runtime: ToolRuntime,
|
||||
Api.files: Files,
|
||||
Api.prompts: Prompts,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
|
@ -284,7 +286,15 @@ async def instantiate_providers(
|
|||
if provider.provider_id is None:
|
||||
continue
|
||||
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
try:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
except KeyError as e:
|
||||
missing_api = e.args[0]
|
||||
raise RuntimeError(
|
||||
f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
|
||||
f"required dependency '{missing_api.value}' is not available. "
|
||||
f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
|
||||
) from e
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
if a in impls:
|
||||
deps[a] = impls[a]
|
||||
|
|
|
@ -78,7 +78,10 @@ async def get_auto_router_impl(
|
|||
|
||||
# TODO: move pass configs to routers instead
|
||||
if api == Api.inference and run_config.inference_store:
|
||||
inference_store = InferenceStore(run_config.inference_store, policy)
|
||||
inference_store = InferenceStore(
|
||||
config=run_config.inference_store,
|
||||
policy=policy,
|
||||
)
|
||||
await inference_store.initialize()
|
||||
api_to_dep_impl["store"] = inference_store
|
||||
|
||||
|
|
|
@ -19,8 +19,6 @@ from llama_stack.apis.common.content_types import (
|
|||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
|
@ -59,7 +57,7 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
|||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||
from llama_stack.providers.utils.telemetry.tracing import enqueue_event, get_current_span
|
||||
|
||||
logger = get_logger(name=__name__, category="core::routers")
|
||||
|
||||
|
@ -86,6 +84,11 @@ class InferenceRouter(Inference):
|
|||
|
||||
async def shutdown(self) -> None:
|
||||
logger.debug("InferenceRouter.shutdown")
|
||||
if self.store:
|
||||
try:
|
||||
await self.store.shutdown()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during InferenceStore shutdown: {e}")
|
||||
|
||||
async def register_model(
|
||||
self,
|
||||
|
@ -156,7 +159,7 @@ class InferenceRouter(Inference):
|
|||
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
|
||||
if self.telemetry:
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
|
@ -264,30 +267,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
return response
|
||||
|
||||
async def batch_chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages_batch: list[list[Message]],
|
||||
tools: list[ToolDefinition] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchChatCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_chat_completion(
|
||||
model_id=model_id,
|
||||
messages_batch=messages_batch,
|
||||
tools=tools,
|
||||
tool_config=tool_config,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -329,20 +308,6 @@ class InferenceRouter(Inference):
|
|||
|
||||
return response
|
||||
|
||||
async def batch_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content_batch: list[InterleavedContent],
|
||||
sampling_params: SamplingParams | None = None,
|
||||
response_format: ResponseFormat | None = None,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
) -> BatchCompletionResponse:
|
||||
logger.debug(
|
||||
f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
|
||||
)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -408,7 +373,7 @@ class InferenceRouter(Inference):
|
|||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
|
@ -504,7 +469,7 @@ class InferenceRouter(Inference):
|
|||
|
||||
# Store the response with the ID that will be returned to the client
|
||||
if self.store:
|
||||
await self.store.store_chat_completion(response, messages)
|
||||
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||
|
||||
if self.telemetry:
|
||||
metrics = self._construct_metrics(
|
||||
|
@ -514,7 +479,7 @@ class InferenceRouter(Inference):
|
|||
model=model_obj,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
# these metrics will show up in the client response.
|
||||
response.metrics = (
|
||||
metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics
|
||||
|
@ -641,7 +606,7 @@ class InferenceRouter(Inference):
|
|||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
async_metrics = [
|
||||
|
@ -687,7 +652,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
|
||||
|
@ -732,7 +697,7 @@ class InferenceRouter(Inference):
|
|||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"finish_reason": "stop",
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
@ -783,7 +748,7 @@ class InferenceRouter(Inference):
|
|||
model=model,
|
||||
)
|
||||
for metric in metrics:
|
||||
await self.telemetry.log_event(metric)
|
||||
enqueue_event(metric)
|
||||
|
||||
yield chunk
|
||||
finally:
|
||||
|
@ -832,4 +797,4 @@ class InferenceRouter(Inference):
|
|||
object="chat.completion",
|
||||
)
|
||||
logger.debug(f"InferenceRouter.completion_response: {final_response}")
|
||||
await self.store.store_chat_completion(final_response, messages)
|
||||
asyncio.create_task(self.store.store_chat_completion(final_response, messages))
|
||||
|
|
|
@ -56,3 +56,7 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
|||
provider_resource_id=provider_benchmark_id,
|
||||
)
|
||||
await self.register_object(benchmark)
|
||||
|
||||
async def unregister_benchmark(self, benchmark_id: str) -> None:
|
||||
existing_benchmark = await self.get_benchmark(benchmark_id)
|
||||
await self.unregister_object(existing_benchmark)
|
||||
|
|
|
@ -64,6 +64,10 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
return await p.unregister_shield(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
elif api == Api.eval:
|
||||
return await p.unregister_benchmark(obj.identifier)
|
||||
elif api == Api.scoring:
|
||||
return await p.unregister_scoring_function(obj.identifier)
|
||||
elif api == Api.tool_runtime:
|
||||
return await p.unregister_toolgroup(obj.identifier)
|
||||
else:
|
||||
|
|
|
@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
|
|
|
@ -60,3 +60,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
|||
)
|
||||
scoring_fn.provider_id = provider_id
|
||||
await self.register_object(scoring_fn)
|
||||
|
||||
async def unregister_scoring_function(self, scoring_fn_id: str) -> None:
|
||||
existing_scoring_fn = await self.get_scoring_function(scoring_fn_id)
|
||||
await self.unregister_object(existing_scoring_fn)
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import Any
|
|||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.errors import ToolGroupNotFoundError
|
||||
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
|
||||
from llama_stack.core.datatypes import ToolGroupWithOwner
|
||||
from llama_stack.core.datatypes import AuthenticationRequiredError, ToolGroupWithOwner
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl
|
||||
|
@ -54,7 +54,18 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
all_tools = []
|
||||
for toolgroup in toolgroups:
|
||||
if toolgroup.identifier not in self.toolgroups_to_tools:
|
||||
await self._index_tools(toolgroup)
|
||||
try:
|
||||
await self._index_tools(toolgroup)
|
||||
except AuthenticationRequiredError:
|
||||
# Send authentication errors back to the client so it knows
|
||||
# that it needs to supply credentials for remote MCP servers.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Other errors that the client cannot fix are logged and
|
||||
# those specific toolgroups are skipped.
|
||||
logger.warning(f"Error listing tools for toolgroup {toolgroup.identifier}: {e}")
|
||||
logger.debug(e, exc_info=True)
|
||||
continue
|
||||
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
|
||||
|
||||
return ListToolsResponse(data=all_tools)
|
||||
|
|
|
@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
provider_vector_db_id: str | None = None,
|
||||
vector_db_name: str | None = None,
|
||||
) -> VectorDB:
|
||||
provider_vector_db_id = provider_vector_db_id or vector_db_id
|
||||
if provider_id is None:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
|
||||
provider = self.impls_by_provider_id[provider_id]
|
||||
logger.warning(
|
||||
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||
)
|
||||
vector_store = await provider.openai_create_vector_store(
|
||||
name=vector_db_name or vector_db_id,
|
||||
embedding_model=embedding_model,
|
||||
embedding_dimension=model.metadata["embedding_dimension"],
|
||||
provider_id=provider_id,
|
||||
provider_vector_db_id=provider_vector_db_id,
|
||||
)
|
||||
|
||||
vector_store_id = vector_store.id
|
||||
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||
logger.warning(
|
||||
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||
)
|
||||
|
||||
vector_db_data = {
|
||||
"identifier": vector_db_id,
|
||||
"identifier": vector_store_id,
|
||||
"type": ResourceType.vector_db.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_vector_db_id,
|
||||
"provider_resource_id": actual_provider_vector_db_id,
|
||||
"embedding_model": embedding_model,
|
||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||
"vector_db_name": vector_db_name,
|
||||
"vector_db_name": vector_store.name,
|
||||
}
|
||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||
await self.register_object(vector_db)
|
||||
|
|
|
@ -8,16 +8,18 @@ import ssl
|
|||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from urllib.parse import parse_qs, urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.errors import TokenValidationError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
KubernetesAuthProviderConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
|
@ -162,7 +164,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
auth=auth,
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
logger.warning(f"Token introspection failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Token introspection failed: {response.status_code}")
|
||||
|
||||
|
@ -272,7 +274,7 @@ class CustomAuthProvider(AuthProvider):
|
|||
json=auth_request.model_dump(),
|
||||
timeout=10.0, # Add a reasonable timeout
|
||||
)
|
||||
if response.status_code != 200:
|
||||
if response.status_code != httpx.codes.OK:
|
||||
logger.warning(f"Authentication failed with status code: {response.status_code}")
|
||||
raise ValueError(f"Authentication failed: {response.status_code}")
|
||||
|
||||
|
@ -374,6 +376,89 @@ async def _get_github_user_info(access_token: str, github_api_base_url: str) ->
|
|||
}
|
||||
|
||||
|
||||
class KubernetesAuthProvider(AuthProvider):
|
||||
"""
|
||||
Kubernetes authentication provider that validates tokens using the Kubernetes SelfSubjectReview API.
|
||||
This provider integrates with Kubernetes API server by using the
|
||||
/apis/authentication.k8s.io/v1/selfsubjectreviews endpoint to validate tokens and extract user information.
|
||||
"""
|
||||
|
||||
def __init__(self, config: KubernetesAuthProviderConfig):
|
||||
self.config = config
|
||||
|
||||
def _httpx_verify_value(self) -> bool | str:
|
||||
"""
|
||||
Build the value for httpx's `verify` parameter.
|
||||
- False disables verification.
|
||||
- Path string points to a CA bundle.
|
||||
- True uses system defaults.
|
||||
"""
|
||||
if not self.config.verify_tls:
|
||||
return False
|
||||
if self.config.tls_cafile:
|
||||
return self.config.tls_cafile.as_posix()
|
||||
return True
|
||||
|
||||
async def validate_token(self, token: str, scope: dict | None = None) -> User:
|
||||
"""Validate a token using Kubernetes SelfSubjectReview API endpoint."""
|
||||
# Build the Kubernetes SelfSubjectReview API endpoint URL
|
||||
review_api_url = urljoin(self.config.api_server_url, "/apis/authentication.k8s.io/v1/selfsubjectreviews")
|
||||
|
||||
# Create SelfSubjectReview request body
|
||||
review_request = {"apiVersion": "authentication.k8s.io/v1", "kind": "SelfSubjectReview"}
|
||||
verify = self._httpx_verify_value()
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(verify=verify, timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
review_api_url,
|
||||
json=review_request,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code == httpx.codes.UNAUTHORIZED:
|
||||
raise TokenValidationError("Invalid token")
|
||||
if response.status_code != httpx.codes.CREATED:
|
||||
logger.warning(f"Kubernetes SelfSubjectReview API failed with status code: {response.status_code}")
|
||||
raise TokenValidationError(f"Token validation failed: {response.status_code}")
|
||||
|
||||
review_response = response.json()
|
||||
# Extract user information from SelfSubjectReview response
|
||||
status = review_response.get("status", {})
|
||||
if not status:
|
||||
raise ValueError("No status found in SelfSubjectReview response")
|
||||
|
||||
user_info = status.get("userInfo", {})
|
||||
if not user_info:
|
||||
raise ValueError("No userInfo found in SelfSubjectReview response")
|
||||
|
||||
username = user_info.get("username")
|
||||
if not username:
|
||||
raise ValueError("No username found in SelfSubjectReview response")
|
||||
|
||||
# Build user attributes from Kubernetes user info
|
||||
user_attributes = get_attributes_from_claims(user_info, self.config.claims_mapping)
|
||||
|
||||
return User(
|
||||
principal=username,
|
||||
attributes=user_attributes,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("Kubernetes SelfSubjectReview API request timed out")
|
||||
raise ValueError("Token validation timeout") from None
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during token validation: {str(e)}")
|
||||
raise ValueError(f"Token validation error: {str(e)}") from e
|
||||
|
||||
async def close(self):
|
||||
"""Close any resources."""
|
||||
pass
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
||||
"""Factory function to create the appropriate auth provider."""
|
||||
provider_config = config.provider_config
|
||||
|
@ -384,5 +469,7 @@ def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
|
|||
return OAuth2TokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, GitHubTokenAuthConfig):
|
||||
return GitHubTokenAuthProvider(provider_config)
|
||||
elif isinstance(provider_config, KubernetesAuthProviderConfig):
|
||||
return KubernetesAuthProvider(provider_config)
|
||||
else:
|
||||
raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
||||
|
|
|
@ -14,7 +14,6 @@ from starlette.routing import Route
|
|||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
|
@ -54,22 +53,23 @@ def get_all_api_routes(
|
|||
protocol_methods.append((f"{tool_group.value}.{name}", method))
|
||||
|
||||
for name, method in protocol_methods:
|
||||
if not hasattr(method, "__webmethod__"):
|
||||
# Get all webmethods for this method (supports multiple decorators)
|
||||
webmethods = getattr(method, "__webmethods__", [])
|
||||
if not webmethods:
|
||||
continue
|
||||
|
||||
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
|
||||
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
|
||||
webmethod = method.__webmethod__ # type: ignore[attr-defined]
|
||||
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == hdrs.METH_GET:
|
||||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
# Create routes for each webmethod decorator
|
||||
for webmethod in webmethods:
|
||||
path = f"/{webmethod.level}/{webmethod.route.lstrip('/')}"
|
||||
if webmethod.method == hdrs.METH_GET:
|
||||
http_method = hdrs.METH_GET
|
||||
elif webmethod.method == hdrs.METH_DELETE:
|
||||
http_method = hdrs.METH_DELETE
|
||||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
|
||||
apis[api] = routes
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
|
@ -24,7 +25,6 @@ from typing import Annotated, Any, get_origin
|
|||
import httpx
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from aiohttp import hdrs
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
|
@ -44,23 +44,17 @@ from llama_stack.core.datatypes import (
|
|||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
from llama_stack.core.stack import (
|
||||
Stack,
|
||||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
shutdown_stack,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
|
@ -74,13 +68,12 @@ from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
|||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
from .tracing import TracingMiddleware
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
@ -132,15 +125,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
},
|
||||
)
|
||||
elif isinstance(exc, ConflictError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
|
||||
elif isinstance(exc, ResourceNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, BadRequestError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
|
||||
elif isinstance(exc, PermissionError | AccessDeniedError):
|
||||
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
|
||||
elif isinstance(exc, ConnectionError | httpx.ConnectError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
|
||||
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
|
||||
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
|
||||
elif isinstance(exc, NotImplementedError):
|
||||
|
@ -154,21 +149,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
)
|
||||
|
||||
|
||||
async def shutdown(app):
|
||||
"""Initiate a graceful shutdown of the application.
|
||||
|
||||
Handled by the lifespan context manager. The shutdown process involves
|
||||
shutting down all implementations registered in the application.
|
||||
class StackApp(FastAPI):
|
||||
"""
|
||||
await shutdown_stack(app.__llama_stack_impls__)
|
||||
A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
|
||||
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, config: StackRunConfig, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.stack: Stack = Stack(config)
|
||||
|
||||
# This code is called from a running event loop managed by uvicorn so we cannot simply call
|
||||
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
|
||||
# function.
|
||||
# As a workaround, we use a thread pool executor to run the initialize() method
|
||||
# in a separate thread.
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, self.stack.initialize())
|
||||
future.result()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def lifespan(app: StackApp):
|
||||
logger.info("Starting up")
|
||||
assert app.stack is not None
|
||||
app.stack.create_registry_refresh_task()
|
||||
yield
|
||||
logger.info("Shutting down")
|
||||
await shutdown(app)
|
||||
await app.stack.shutdown()
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
|
@ -285,65 +293,6 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
return route_handler
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
class ClientVersionMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
@ -384,73 +333,61 @@ class ClientVersionMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
def create_app(
|
||||
config_file: str | None = None,
|
||||
env_vars: list[str] | None = None,
|
||||
) -> StackApp:
|
||||
"""Create and configure the FastAPI application.
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
Args:
|
||||
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
|
||||
env_vars: List of environment variables in KEY=value format.
|
||||
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
Returns:
|
||||
Configured StackApp instance.
|
||||
"""
|
||||
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
|
||||
if config_file is None:
|
||||
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
config_file = resolve_config_or_distro(config_file, Mode.RUN)
|
||||
|
||||
# Load and process configuration
|
||||
logger_config = None
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||
if args.env:
|
||||
for env_pair in args.env:
|
||||
|
||||
if env_vars:
|
||||
for env_pair in env_vars:
|
||||
try:
|
||||
key, value = validate_env_pair(env_pair)
|
||||
logger.info(f"Setting CLI environment variable {key} => {value}")
|
||||
logger.info(f"Setting environment variable {key} => {value}")
|
||||
os.environ[key] = value
|
||||
except ValueError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
|
||||
|
||||
config = replace_env_vars(config_contents)
|
||||
config = StackRunConfig(**cast_image_name_to_string(config))
|
||||
|
||||
_log_run_config(run_config=config)
|
||||
|
||||
app = FastAPI(
|
||||
app = StackApp(
|
||||
lifespan=lifespan,
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
config=config,
|
||||
)
|
||||
|
||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
impls = app.stack.impls
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
|
@ -513,6 +450,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
|
||||
apis_to_serve.add("inspect")
|
||||
apis_to_serve.add("providers")
|
||||
apis_to_serve.add("prompts")
|
||||
for api_str in apis_to_serve:
|
||||
api = Api(api_str)
|
||||
|
||||
|
@ -550,9 +488,54 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.exception_handler(RequestValidationError)(global_exception_handler)
|
||||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
||||
add_config_distro_args(parser)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
|
||||
help="Port to listen on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--env",
|
||||
action="append",
|
||||
help="Environment variables in KEY=value format. Can be specified multiple times.",
|
||||
)
|
||||
|
||||
# Determine whether the server args are being passed by the "run" command, if this is the case
|
||||
# the args will be passed as a Namespace object to the main function, otherwise they will be
|
||||
# parsed from the command line
|
||||
if args is None:
|
||||
args = parser.parse_args()
|
||||
|
||||
config_or_distro = get_config_from_args(args)
|
||||
|
||||
try:
|
||||
app = create_app(
|
||||
config_file=config_or_distro,
|
||||
env_vars=args.env,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating app: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
|
||||
with open(config_file) as fp:
|
||||
config_contents = yaml.safe_load(fp)
|
||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||
logger_config = LoggingConfig(**cfg)
|
||||
else:
|
||||
logger_config = None
|
||||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Configure SSL if certificates are provided
|
||||
|
@ -590,7 +573,6 @@ def main(args: argparse.Namespace | None = None):
|
|||
if ssl_config:
|
||||
uvicorn_config.update(ssl_config)
|
||||
|
||||
# Run uvicorn in the existing event loop to preserve background tasks
|
||||
# We need to catch KeyboardInterrupt because uvicorn's signal handling
|
||||
# re-raises SIGINT signals using signal.raise_signal(), which Python
|
||||
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
|
||||
|
@ -601,13 +583,9 @@ def main(args: argparse.Namespace | None = None):
|
|||
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
|
||||
# signal handling but this is quite intrusive and not worth the effort.
|
||||
try:
|
||||
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
logger.info("Received interrupt signal, shutting down gracefully...")
|
||||
finally:
|
||||
if not loop.is_closed():
|
||||
logger.debug("Closing event loop")
|
||||
loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
|
||||
|
|
80
llama_stack/core/server/tracing.py
Normal file
80
llama_stack/core/server/tracing.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.core.external import ExternalApiSpec
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
|
||||
|
||||
logger = get_logger(name=__name__, category="core::server")
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope.get("type") == "lifespan":
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
path = scope.get("path", "")
|
||||
|
||||
# Check if the path is a FastAPI built-in path
|
||||
if path.startswith(self.fastapi_paths):
|
||||
# Pass through to FastAPI's built-in handlers
|
||||
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
# Log deprecation warning if route is deprecated
|
||||
if getattr(webmethod, "deprecated", False):
|
||||
logger.warning(
|
||||
f"DEPRECATED ROUTE USED: {scope.get('method', 'GET')} {path} - "
|
||||
f"This route is deprecated and may be removed in a future version. "
|
||||
f"Please check the docs for the supported version."
|
||||
)
|
||||
|
||||
trace_attributes = {"__location__": "server", "raw_path": path}
|
||||
|
||||
# Extract W3C trace context headers and store as trace attributes
|
||||
headers = dict(scope.get("headers", []))
|
||||
traceparent = headers.get(b"traceparent", b"").decode()
|
||||
if traceparent:
|
||||
trace_attributes["traceparent"] = traceparent
|
||||
tracestate = headers.get(b"tracestate", b"").decode()
|
||||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
if message["type"] == "http.response.start":
|
||||
headers = message.get("headers", [])
|
||||
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
|
||||
message["headers"] = headers
|
||||
await send(message)
|
||||
|
||||
try:
|
||||
return await self.app(scope, receive, send_with_trace_id)
|
||||
finally:
|
||||
await end_trace()
|
|
@ -14,7 +14,6 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batch_inference import BatchInference
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -24,6 +23,7 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.apis.post_training import PostTraining
|
||||
from llama_stack.apis.prompts import Prompts
|
||||
from llama_stack.apis.providers import Providers
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
|
@ -37,6 +37,7 @@ from llama_stack.apis.vector_io import VectorIO
|
|||
from llama_stack.core.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.core.prompts.prompts import PromptServiceConfig, PromptServiceImpl
|
||||
from llama_stack.core.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.core.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.core.routing_tables.common import CommonRoutingTableImpl
|
||||
|
@ -52,7 +53,6 @@ class LlamaStack(
|
|||
Providers,
|
||||
VectorDBs,
|
||||
Inference,
|
||||
BatchInference,
|
||||
Agents,
|
||||
Safety,
|
||||
SyntheticDataGeneration,
|
||||
|
@ -72,6 +72,7 @@ class LlamaStack(
|
|||
ToolRuntime,
|
||||
RAGToolRuntime,
|
||||
Files,
|
||||
Prompts,
|
||||
):
|
||||
pass
|
||||
|
||||
|
@ -305,76 +306,91 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
|
|||
)
|
||||
impls[Api.providers] = providers_impl
|
||||
|
||||
prompts_impl = PromptServiceImpl(
|
||||
PromptServiceConfig(run_config=run_config),
|
||||
deps=impls,
|
||||
)
|
||||
impls[Api.prompts] = prompts_impl
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def construct_stack(
|
||||
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
|
||||
) -> dict[Api, Any]:
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
|
||||
class Stack:
|
||||
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
|
||||
self.run_config = run_config
|
||||
self.provider_registry = provider_registry
|
||||
self.impls = None
|
||||
|
||||
# Produces a stack of providers for the given run config. Not all APIs may be
|
||||
# asked for in the run config.
|
||||
async def initialize(self):
|
||||
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
|
||||
from llama_stack.testing.inference_recorder import setup_inference_recording
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
|
||||
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, self.run_config)
|
||||
|
||||
if Api.prompts in impls:
|
||||
await impls[Api.prompts].initialize()
|
||||
|
||||
await register_resources(self.run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
self.impls = impls
|
||||
|
||||
def create_registry_refresh_task(self):
|
||||
assert self.impls is not None, "Must call initialize() before starting"
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
logger.error("Model refresh task cancelled")
|
||||
elif task.exception():
|
||||
logger.error(f"Model refresh task failed: {task.exception()}")
|
||||
traceback.print_exception(task.exception())
|
||||
else:
|
||||
logger.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
|
||||
async def shutdown(self):
|
||||
for impl in self.impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning(f"No shutdown method for {impl_name}")
|
||||
except TimeoutError:
|
||||
logger.exception(f"Shutdown timeout for {impl_name}")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
TEST_RECORDING_CONTEXT = setup_inference_recording()
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
TEST_RECORDING_CONTEXT.__enter__()
|
||||
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
|
||||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
policy = run_config.server.auth.access_policy if run_config.server.auth else []
|
||||
impls = await resolve_impls(
|
||||
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
|
||||
)
|
||||
|
||||
# Add internal implementations after all other providers are resolved
|
||||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
logger.error("Model refresh task cancelled")
|
||||
elif task.exception():
|
||||
logger.error(f"Model refresh task failed: {task.exception()}")
|
||||
traceback.print_exception(task.exception())
|
||||
else:
|
||||
logger.debug("Model refresh task completed")
|
||||
|
||||
REGISTRY_REFRESH_TASK.add_done_callback(cb)
|
||||
return impls
|
||||
|
||||
|
||||
async def shutdown_stack(impls: dict[Api, Any]):
|
||||
for impl in impls.values():
|
||||
impl_name = impl.__class__.__name__
|
||||
logger.info(f"Shutting down {impl_name}")
|
||||
try:
|
||||
if hasattr(impl, "shutdown"):
|
||||
await asyncio.wait_for(impl.shutdown(), timeout=5)
|
||||
else:
|
||||
logger.warning(f"No shutdown method for {impl_name}")
|
||||
except TimeoutError:
|
||||
logger.exception(f"Shutdown timeout for {impl_name}")
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
logger.exception(f"Failed to shutdown {impl_name}: {e}")
|
||||
|
||||
global TEST_RECORDING_CONTEXT
|
||||
if TEST_RECORDING_CONTEXT:
|
||||
try:
|
||||
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during inference recording cleanup: {e}")
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
REGISTRY_REFRESH_TASK.cancel()
|
||||
global REGISTRY_REFRESH_TASK
|
||||
if REGISTRY_REFRESH_TASK:
|
||||
REGISTRY_REFRESH_TASK.cancel()
|
||||
|
||||
|
||||
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||
|
|
|
@ -123,6 +123,6 @@ if [[ "$env_type" == "venv" ]]; then
|
|||
$other_args
|
||||
elif [[ "$env_type" == "container" ]]; then
|
||||
echo -e "${RED}Warning: Llama Stack no longer supports running Containers via the 'llama stack run' command.${NC}"
|
||||
echo -e "Please refer to the documentation for more information: https://llama-stack.readthedocs.io/en/latest/distributions/building_distro.html#llama-stack-build"
|
||||
echo -e "Please refer to the documentation for more information: https://llamastack.github.io/latest/distributions/building_distro.html#llama-stack-build"
|
||||
exit 1
|
||||
fi
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
## Developer Setup
|
||||
|
||||
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
|
||||
1. Start up Llama Stack API server. More details [here](https://llamastack.github.io/latest/getting_started/index.htmll).
|
||||
|
||||
```
|
||||
llama stack build --distro together --image-type venv
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue