Merge branch 'main' into chore_build

This commit is contained in:
Wen Zhou 2025-06-26 13:30:43 +02:00 committed by GitHub
commit 604e42c56d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
196 changed files with 2332 additions and 1515 deletions

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from typing import Self
from pydantic import BaseModel, model_validator
@ -12,7 +12,7 @@ from pydantic import BaseModel, model_validator
from .conditions import parse_conditions
class Action(str, Enum):
class Action(StrEnum):
CREATE = "create"
READ = "read"
UPDATE = "update"

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from enum import StrEnum
from pathlib import Path
from typing import Annotated, Any
@ -29,8 +29,8 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
LLAMA_STACK_BUILD_CONFIG_VERSION = 2
LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str]
@ -159,7 +159,7 @@ class LoggingConfig(BaseModel):
)
class AuthProviderType(str, Enum):
class AuthProviderType(StrEnum):
"""Supported authentication provider types."""
OAUTH2_TOKEN = "oauth2_token"
@ -182,7 +182,7 @@ class AuthenticationRequiredError(Exception):
pass
class QuotaPeriod(str, Enum):
class QuotaPeriod(StrEnum):
DAY = "day"
@ -229,7 +229,7 @@ class ServerConfig(BaseModel):
class StackRunConfig(BaseModel):
version: str = LLAMA_STACK_RUN_CONFIG_VERSION
version: int = LLAMA_STACK_RUN_CONFIG_VERSION
image_name: str = Field(
...,
@ -300,7 +300,7 @@ a default SQLite store will be used.""",
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
version: int = LLAMA_STACK_BUILD_CONFIG_VERSION
distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ")
image_type: str = Field(

View file

@ -30,7 +30,13 @@ from llama_stack.apis.inference import (
ListOpenAIChatCompletionResponse,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAICompletionWithInputMessages,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
Order,
ResponseFormat,
SamplingParams,
@ -41,14 +47,6 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.log import get_logger

View file

@ -16,17 +16,15 @@ from llama_stack.apis.vector_io import (
QueryChunksResponse,
SearchRankingOptions,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_io.vector_io import (
VectorStoreChunkingStrategy,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFileStatus,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable

View file

@ -98,6 +98,15 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
method = getattr(impls[api], register_method)
for obj in objects:
# In complex templates, like our starter template, we may have dynamic model ids
# given by environment variables. This allows those environment variables to have
# a default value of __disabled__ to skip registration of the model if not set.
if (
hasattr(obj, "provider_model_id")
and obj.provider_model_id is not None
and "__disabled__" in obj.provider_model_id
):
continue
# we want to maintain the type information in arguments to method.
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
# we use model_dump() to find all the attrs and then getattr to get the still typed value.
@ -118,7 +127,12 @@ class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
super().__init__(
f"Environment variable '{var_name}' not set or empty {f'at {path}' if path else ''}. "
f"Use ${{env.{var_name}:=default_value}} to provide a default value, "
f"${{env.{var_name}:+value_if_set}} to make the field conditional, "
f"or ensure the environment variable is set."
)
def replace_env_vars(config: Any, path: str = "") -> Any:
@ -141,25 +155,27 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return result
elif isinstance(config, str):
# Updated pattern to support both default values (:) and conditional values (+)
pattern = r"\${env\.([A-Z0-9_]+)(?:([:\+])([^}]*))?}"
# Pattern supports bash-like syntax: := for default and :+ for conditional and a optional value
pattern = r"\${env\.([A-Z0-9_]+)(?::([=+])([^}]*))?}"
def get_env_var(match):
def get_env_var(match: re.Match):
env_var = match.group(1)
operator = match.group(2) # ':' for default, '+' for conditional
operator = match.group(2) # '=' for default, '+' for conditional
value_expr = match.group(3)
env_value = os.environ.get(env_var)
if operator == ":": # Default value syntax: ${env.FOO:default}
if operator == "=": # Default value syntax: ${env.FOO:=default}
if not env_value:
if value_expr is None:
# value_expr returns empty string (not None) when not matched
# This means ${env.FOO:=} is an error
if value_expr == "":
raise EnvVarError(env_var, path)
else:
value = value_expr
else:
value = env_value
elif operator == "+": # Conditional value syntax: ${env.FOO+value_if_set}
elif operator == "+": # Conditional value syntax: ${env.FOO:+value_if_set}
if env_value:
value = value_expr
else:
@ -174,13 +190,42 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
return os.path.expanduser(value)
try:
return re.sub(pattern, get_env_var, config)
result = re.sub(pattern, get_env_var, config)
return _convert_string_to_proper_type(result)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return config
def _convert_string_to_proper_type(value: str) -> Any:
# This might be tricky depending on what the config type is, if 'str | None' we are
# good, if 'str' we need to keep the empty string... 'str | None' is more common and
# providers config should be typed this way.
# TODO: we could try to load the config class and see if the config has a field with type 'str | None'
# and then convert the empty string to None or not
if value == "":
return None
lowered = value.lower()
if lowered == "true":
return True
elif lowered == "false":
return False
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:

View file

@ -25,7 +25,7 @@ class LlamaStackApi:
def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: dict | None):
"""Run scoring on a single row"""
if not scoring_params:
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
scoring_params = dict.fromkeys(scoring_function_ids)
return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params)