Merge branch 'main' into chore/strong-typing

Stefan Thaler 2025-07-21 07:40:00 +01:00 committed by GitHub
commit 16d6a7a22f
84 changed files with 3177 additions and 2793 deletions


@@ -12,11 +12,13 @@ import os
import sys
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from io import BytesIO
from pathlib import Path
from typing import Any, TypeVar, Union, get_args, get_origin
import httpx
import yaml
from fastapi import Response as FastAPIResponse
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
@@ -112,6 +114,27 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any:
raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e
class LibraryClientUploadFile:
"""LibraryClient UploadFile object that mimics FastAPI's UploadFile interface."""
def __init__(self, filename: str, content: bytes):
self.filename = filename
self.content = content
self.content_type = "application/octet-stream"
async def read(self) -> bytes:
return self.content
class LibraryClientHttpxResponse:
"""LibraryClient httpx Response object for FastAPI Response conversion."""
def __init__(self, response):
self.content = response.body if isinstance(response.body, bytes) else response.body.encode()
self.status_code = response.status_code
self.headers = response.headers
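For orientation, a minimal sketch (not part of the diff) of how the two shims above behave, assuming only the class definitions in this hunk:

import asyncio
from fastapi import Response as FastAPIResponse

upload = LibraryClientUploadFile("notes.txt", b"hello world")
assert upload.content_type == "application/octet-stream"
assert asyncio.run(upload.read()) == b"hello world"  # read() mirrors FastAPI UploadFile's async read()

wrapped = LibraryClientHttpxResponse(FastAPIResponse(content=b"raw bytes", status_code=200))
assert wrapped.content == b"raw bytes" and wrapped.status_code == 200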
class LlamaStackAsLibraryClient(LlamaStackClient):
def __init__(
self,
@@ -128,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
def initialize(self):
if in_notebook():
import nest_asyncio
@@ -136,7 +161,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
return asyncio.run(self.async_client.initialize())
return self.loop.run_until_complete(self.async_client.initialize())
def _remove_root_logger_handlers(self):
"""
@@ -149,10 +174,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
def request(self, *args, **kwargs):
# NOTE: We are using AsyncLlamaStackClient under the hood
# A new event loop is needed to convert the AsyncStream
# from async client into SyncStream return type for streaming
loop = asyncio.new_event_loop()
loop = self.loop
asyncio.set_event_loop(loop)
if kwargs.get("stream"):
@@ -169,7 +191,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return sync_generator()
else:
@@ -179,7 +200,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
pending = asyncio.all_tasks(loop)
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
loop.close()
return result
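As a standalone illustration (names here are placeholders, not from the diff), the pattern this change adopts is: one long-lived loop per client, with pending tasks drained after each call instead of closing the loop.

import asyncio

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

async def handle_request():
    return {"ok": True}  # stand-in for awaiting the async client

result = loop.run_until_complete(handle_request())
pending = asyncio.all_tasks(loop)
if pending:
    loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
# the loop is deliberately left open so later calls (and streaming generators) can reuse it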
@@ -295,6 +315,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return response
def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]:
"""Handle file uploads from OpenAI client and add them to the request body."""
if not (hasattr(options, "files") and options.files):
return body, []
if not isinstance(options.files, list):
return body, []
field_names = []
for file_tuple in options.files:
if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2):
continue
field_name = file_tuple[0]
file_object = file_tuple[1]
if isinstance(file_object, BytesIO):
file_object.seek(0)
file_content = file_object.read()
filename = getattr(file_object, "name", "uploaded_file")
field_names.append(field_name)
body[field_name] = LibraryClientUploadFile(filename, file_content)
return body, field_names
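A hedged usage sketch of _handle_file_uploads; `client` is assumed to be an initialized AsyncLlamaStackAsLibraryClient, and the (field_name, file_object) tuple shape matches what the loop above iterates over:

from io import BytesIO

buf = BytesIO(b"example file bytes")
buf.name = "report.txt"  # picked up as the filename; otherwise "uploaded_file" is used

class FakeOptions:  # hypothetical stand-in for the request options object
    files = [("file", buf)]

body, field_names = client._handle_file_uploads(FakeOptions(), {})
# field_names == ["file"]; body["file"] is a LibraryClientUploadFile carrying the bytes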
async def _call_non_streaming(
self,
*,
@@ -310,15 +355,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
body |= path_params
body = self._convert_body(path, options.method, body)
body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
await start_trace(route, {"__location__": "library_client"})
try:
result = await matched_func(**body)
finally:
await end_trace()
# Handle FastAPI Response objects (e.g., from file content retrieval)
if isinstance(result, FastAPIResponse):
return LibraryClientHttpxResponse(result)
json_content = json.dumps(convert_pydantic_to_json_value(result))
filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)}
mock_response = httpx.Response(
status_code=httpx.codes.OK,
content=json_content.encode("utf-8"),
@@ -330,7 +383,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
url=options.url,
params=options.params,
headers=options.headers or {},
json=convert_pydantic_to_json_value(body),
json=convert_pydantic_to_json_value(filtered_body),
),
)
response = APIResponse(
@@ -405,13 +458,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
return await response.parse()
def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict:
def _convert_body(
self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None
) -> dict:
if not body:
return {}
if self.route_impls is None:
raise ValueError("Client not initialized")
exclude_params = exclude_params or set()
func, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func)
@@ -423,6 +480,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
if param_name in exclude_params:
converted_body[param_name] = value
else:
converted_body[param_name] = convert_to_pydantic(param.annotation, value)
return converted_body
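What exclude_params buys, as a sketch (the path and field names below are assumed, not taken from the diff): excluded fields bypass convert_to_pydantic so the upload shims reach the route handler untouched, while the remaining fields are still coerced to the handler's annotated types.

upload = LibraryClientUploadFile("doc.txt", b"...")
body = {"purpose": "assistants", "file": upload}
converted = client._convert_body("/v1/openai/v1/files", "POST", body, exclude_params={"file"})
# converted["file"] is the same upload object; converted["purpose"] went through convert_to_pydantic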


@@ -200,7 +200,7 @@ def validate_and_prepare_providers(
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
continue
validate_provider(provider, api, provider_registry)


@@ -80,3 +80,38 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None:
existing_models = await self.get_all_with_type("model")
# we may have an alias for the model registered by the user (or during initialization
# from run.yaml) that we need to keep track of
model_ids = {}
for model in existing_models:
# we leave embeddings models alone because often we don't get metadata
# (embedding dimension, etc.) from the provider
if model.provider_id == provider_id and model.model_type == ModelType.llm:
model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id]
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object(
ModelWithOwner(
identifier=model.identifier,
provider_resource_id=model.provider_resource_id,
provider_id=provider_id,
metadata=model.metadata,
model_type=model.model_type,
)
)
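To make the alias handling concrete, a small illustration with made-up identifiers: model_ids maps provider_resource_id to the previously registered identifier, so a refreshed model keeps the user's alias.

model_ids = {"meta-llama/Llama-3.1-8B-Instruct": "my-fast-llm"}  # hypothetical alias from run.yaml

provider_resource_id = "meta-llama/Llama-3.1-8B-Instruct"
identifier = model_ids.get(provider_resource_id, provider_resource_id)
# identifier == "my-fast-llm": the model is re-registered under the old alias rather than the raw id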


@@ -445,9 +445,7 @@ def main(args: argparse.Namespace | None = None):
# now that the logger is initialized, print the line about which type of config we are using.
logger.info(log_line)
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(config.model_dump(mode="json"))
logger.info(yaml.dump(safe_config, indent=2))
_log_run_config(run_config=config)
app = FastAPI(
lifespan=lifespan,
@@ -455,6 +453,7 @@
redoc_url="/redoc",
openapi_url="/openapi.json",
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
@@ -493,7 +492,13 @@ def main(args: argparse.Namespace | None = None):
)
try:
impls = asyncio.run(construct_stack(config))
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
@@ -591,7 +596,16 @@ def main(args: argparse.Namespace | None = None):
if ssl_config:
uvicorn_config.update(ssl_config)
uvicorn.run(**uvicorn_config)
# Run uvicorn in the existing event loop to preserve background tasks
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
def _log_run_config(run_config: StackRunConfig):
"""Logs the run config with redacted fields and disabled providers removed."""
logger.info("Run configuration:")
safe_config = redact_sensitive_fields(run_config.model_dump(mode="json"))
clean_config = remove_disabled_providers(safe_config)
logger.info(yaml.dump(clean_config, indent=2))
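A standalone, hedged sketch of the server-side pattern above (app, host, and port are placeholders): a single loop owns both stack construction and the uvicorn server, so background tasks started during construction survive into serving.

import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()

async def construct():  # stand-in for construct_stack(config)
    return {}

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
impls = loop.run_until_complete(construct())
loop.run_until_complete(uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=8321)).serve())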
def extract_path_params(route: str) -> list[str]:
@@ -602,5 +616,20 @@ def extract_path_params(route: str) -> list[str]:
return params
def remove_disabled_providers(obj):
if isinstance(obj, dict):
if (
obj.get("provider_id") == "__disabled__"
or obj.get("shield_id") == "__disabled__"
or obj.get("provider_model_id") == "__disabled__"
):
return None
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
elif isinstance(obj, list):
return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None]
else:
return obj
if __name__ == "__main__":
main()
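Example input and output for remove_disabled_providers (the provider entries below are hypothetical):

cfg = {
    "providers": {
        "inference": [
            {"provider_id": "ollama", "provider_type": "remote::ollama"},
            {"provider_id": "__disabled__", "provider_type": "remote::vllm"},
        ]
    }
}
print(remove_disabled_providers(cfg))
# {'providers': {'inference': [{'provider_id': 'ollama', 'provider_type': 'remote::ollama'}]}}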


@@ -172,7 +172,6 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
# Create a copy with resolved provider_id but original config
disabled_provider = v.copy()
disabled_provider["provider_id"] = resolved_provider_id
result.append(disabled_provider)
continue
except EnvVarError:
# If we can't resolve the provider_id, continue with normal processing