Merge branch 'main' into add-watsonx-inference-adapter

This commit is contained in:
Sajikumar JS 2025-03-24 23:49:51 +05:30
commit 363e2565f5
162 changed files with 3845 additions and 3126 deletions

View file

@ -6,13 +6,17 @@
from typing import Any, Dict, Optional
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
logger = get_logger(__name__, category="core")
def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict[str, Any]] = None) -> bool:
def check_access(
obj_identifier: str,
obj_attributes: Optional[AccessAttributes],
user_attributes: Optional[Dict[str, Any]] = None,
) -> bool:
"""Check if the current user has access to the given object, based on access attributes.
Access control algorithm:
@ -43,39 +47,40 @@ def check_access(obj: RoutableObjectWithProvider, user_attributes: Optional[Dict
# - The extra "projects" attribute is ignored
Args:
obj: The resource object to check access for
obj_identifier: The identifier of the resource object to check access for
obj_attributes: The access attributes of the resource object
user_attributes: The attributes of the current user
Returns:
bool: True if access is granted, False if denied
"""
# If object has no access attributes, allow access by default
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
if not obj_attributes:
return True
# If no user attributes, deny access to objects with access control
if not user_attributes:
return False
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
if not obj_attributes:
dict_attribs = obj_attributes.model_dump(exclude_none=True)
if not dict_attribs:
return True
# Check each attribute category (requires ALL categories to match)
for attr_key, required_values in obj_attributes.items():
# TODO: formalize this into a proper ABAC policy
for attr_key, required_values in dict_attribs.items():
user_values = user_attributes.get(attr_key, [])
if not user_values:
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
)
logger.debug(f"Access denied to {obj_identifier}: missing required attribute category '{attr_key}'")
return False
if not any(val in user_values for val in required_values):
logger.debug(
f"Access denied to {obj.type} '{obj.identifier}': "
f"Access denied to {obj_identifier}: "
f"no match for attribute '{attr_key}', required one of {required_values}"
)
return False
logger.debug(f"Access granted to {obj.type} '{obj.identifier}'")
logger.debug(f"Access granted to {obj_identifier}")
return True

View file

@ -92,6 +92,7 @@ RUN apt-get update && apt-get install -y \
procps psmisc lsof \
traceroute \
bubblewrap \
gcc \
&& rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1

View file

@ -9,7 +9,6 @@ import inspect
import json
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
@ -37,7 +36,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import ProviderRegistry
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
get_stack_run_config_from_template,
@ -232,31 +234,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in self.impls:
continue
for endpoint in api_endpoints:
impl = self.impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = func
self.endpoint_impls = endpoint_impls
self.endpoint_impls = initialize_endpoint_impls(self.impls)
return True
async def request(
@ -290,32 +268,6 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return response
def _find_matching_endpoint(self, method: str, path: str) -> tuple[Any, dict]:
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
Returns:
A tuple of (endpoint_function, path_params)
Raises:
ValueError: If no matching endpoint is found
"""
impls = self.endpoint_impls.get(method)
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, func in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params
raise ValueError(f"No endpoint found for {path}")
async def _call_non_streaming(
self,
*,
@ -326,10 +278,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {}
body |= options.json_data or {}
matched_func, path_params = self._find_matching_endpoint(options.method, path)
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
await start_trace(options.url, {"__location__": "library_client"})
await start_trace(route, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
@ -371,13 +323,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url
body = options.params or {}
body |= options.json_data or {}
func, path_params = self._find_matching_endpoint(options.method, path)
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
async def gen():
await start_trace(options.url, {"__location__": "library_client"})
await start_trace(route, {"__location__": "library_client"})
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))
@ -422,7 +374,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
if not body:
return {}
func, _ = self._find_matching_endpoint(method, path)
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature

View file

@ -14,13 +14,7 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
EvaluateResponse,
Job,
JobStatus,
)
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEventType,
@ -623,7 +617,7 @@ class EvalRouter(Eval):
self,
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

View file

@ -198,7 +198,7 @@ class CommonRoutingTableImpl(RoutingTable):
return None
# Check if user has permission to access this object
if not check_access(obj, get_auth_attributes()):
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
@ -241,7 +241,11 @@ class CommonRoutingTableImpl(RoutingTable):
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [obj for obj in filtered_objs if check_access(obj, get_auth_attributes())]
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import inspect
import re
from typing import Dict, List
from pydantic import BaseModel
@ -19,6 +20,7 @@ class ApiEndpoint(BaseModel):
route: str
method: str
name: str
descriptive_name: str | None = None
def toolgroup_protocol_map():
@ -58,8 +60,69 @@ def get_all_api_endpoints() -> Dict[Api, List[ApiEndpoint]]:
method = "delete"
else:
method = "post"
endpoints.append(ApiEndpoint(route=route, method=method, name=name))
endpoints.append(
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
)
apis[api] = endpoints
return apis
def initialize_endpoint_impls(impls):
endpoints = get_all_api_endpoints()
endpoint_impls = {}
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
if api not in impls:
continue
for endpoint in api_endpoints:
impl = impls[api]
func = getattr(impl, endpoint.name)
if endpoint.method not in endpoint_impls:
endpoint_impls[endpoint.method] = {}
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
func,
endpoint.descriptive_name or endpoint.route,
)
return endpoint_impls
def find_matching_endpoint(method, path, endpoint_impls):
"""Find the matching endpoint implementation for a given method and path.
Args:
method: HTTP method (GET, POST, etc.)
path: URL path to match against
endpoint_impls: A dictionary of endpoint implementations
Returns:
A tuple of (endpoint_function, path_params, descriptive_name)
Raises:
ValueError: If no matching endpoint is found
"""
impls = endpoint_impls.get(method.lower())
if not impls:
raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items():
match = re.match(regex, path)
if match:
# Extract named groups from the regex match
path_params = match.groupdict()
return func, path_params, descriptive_name
raise ValueError(f"No endpoint found for {path}")

View file

@ -32,6 +32,10 @@ from llama_stack.distribution.request_headers import (
request_provider_data_context,
)
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.endpoints import (
find_matching_endpoint,
initialize_endpoint_impls,
)
from llama_stack.distribution.stack import (
construct_stack,
redact_sensitive_fields,
@ -222,14 +226,30 @@ def create_dynamic_typed_route(func: Any, method: str, route: str):
class TracingMiddleware:
def __init__(self, app):
def __init__(self, app, impls):
self.app = app
self.impls = impls
async def __call__(self, scope, receive, send):
path = scope.get("path", "")
await start_trace(path, {"__location__": "server"})
try:
if scope.get("type") == "lifespan":
return await self.app(scope, receive, send)
path = scope.get("path", "")
if not hasattr(self, "endpoint_impls"):
self.endpoint_impls = initialize_endpoint_impls(self.impls)
_, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
async def send_with_trace_id(message):
if message["type"] == "http.response.start":
headers = message.get("headers", [])
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
message["headers"] = headers
await send(message)
try:
return await self.app(scope, receive, send_with_trace_id)
finally:
await end_trace()
@ -351,7 +371,6 @@ def main():
logger.info(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
@ -415,6 +434,7 @@ def main():
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls)
import uvicorn

View file

@ -12,9 +12,12 @@ import pydantic
from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
class DistributionRegistry(Protocol):
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
@ -47,8 +50,13 @@ def _parse_registry_values(values: List[str]) -> List[RoutableObjectWithProvider
"""Utility function to parse registry values into RoutableObjectWithProvider objects."""
all_objects = []
for value in values:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
try:
obj = pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(value)
all_objects.append(obj)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value, raw value: {value}. Error: {e}")
continue
return all_objects
@ -73,7 +81,11 @@ class DiskDistributionRegistry(DistributionRegistry):
if not json_str:
return None
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
try:
return pydantic.TypeAdapter(RoutableObjectWithProvider).validate_json(json_str)
except pydantic.ValidationError as e:
logger.error(f"Error parsing registry value for {type}:{identifier}, raw value: {json_str}. Error: {e}")
return None
async def update(self, obj: RoutableObjectWithProvider) -> None:
await self.kvstore.set(

View file

@ -5,9 +5,7 @@
# the root directory of this source tree.
import streamlit as st
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.shared.document import Document
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
from llama_stack.distribution.ui.modules.api import llama_stack_api
from llama_stack.distribution.ui.modules.utils import data_url_from_file
@ -35,7 +33,7 @@ def rag_chat_page():
)
if st.button("Create Vector Database"):
documents = [
Document(
RAGDocument(
document_id=uploaded_file.name,
content=data_url_from_file(uploaded_file),
)
@ -167,7 +165,7 @@ def rag_chat_page():
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
for log in EventLogger().log(response):
for log in AgentEventLogger().log(response):
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()