Merge branch 'main' into eval_api_final

This commit is contained in:
Xi Yan 2025-03-26 12:29:45 -07:00
commit bc0cd07008
79 changed files with 3257 additions and 2358 deletions

View file

@ -12,6 +12,7 @@ from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.evaluation import Evaluation
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@ -74,6 +75,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
Api.evaluation: Evaluation,
Api.files: Files,
}
@ -107,7 +109,9 @@ async def resolve_impls(
2. Sorting them in dependency order.
3. Instantiating them with required dependencies.
"""
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
routing_table_apis = {
x.routing_table_api for x in builtin_automatically_routed_apis()
}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = validate_and_prepare_providers(
@ -115,7 +119,9 @@ async def resolve_impls(
)
apis_to_serve = run_config.apis or set(
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
list(providers_with_specs.keys())
+ [x.value for x in routing_table_apis]
+ [x.value for x in router_apis]
)
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
@ -180,17 +186,23 @@ def validate_and_prepare_providers(
for api_str, providers in run_config.providers.items():
api = Api(api_str)
if api in routing_table_apis:
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
raise ValueError(
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
)
specs = {}
for provider in providers:
if not provider.provider_id or provider.provider_id == "__disabled__":
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
logger.warning(
f"Provider `{provider.provider_type}` for API `{api}` is disabled"
)
continue
validate_provider(provider, api, provider_registry)
p = provider_registry[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
p.deps__ = [a.value for a in p.api_dependencies] + [
a.value for a in p.optional_api_dependencies
]
spec = ProviderWithSpec(spec=p, **provider.model_dump())
specs[provider.provider_id] = spec
@ -200,10 +212,14 @@ def validate_and_prepare_providers(
return providers_with_specs
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
def validate_provider(
provider: Provider, api: Api, provider_registry: ProviderRegistry
):
"""Validates if the provider is allowed and handles deprecations."""
if provider.provider_type not in provider_registry[api]:
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
raise ValueError(
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = provider_registry[api][provider.provider_type]
if p.deprecation_error:
@ -278,7 +294,9 @@ async def instantiate_providers(
) -> Dict:
"""Instantiates providers asynchronously while managing dependencies."""
impls: Dict[Api, Any] = {}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {
f"inner-{x.value}": {} for x in router_apis
}
for api_str, provider in sorted_providers:
deps = {a: impls[a] for a in provider.spec.api_dependencies}
for a in provider.spec.optional_api_dependencies:
@ -287,7 +305,9 @@ async def instantiate_providers(
inner_impls = {}
if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
inner_impls = inner_impls_by_provider_id[
f"inner-{provider.spec.router_api.value}"
]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
@ -345,7 +365,9 @@ async def instantiate_provider(
provider_spec = provider.spec
if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
raise AttributeError(
f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute"
)
module = importlib.import_module(provider_spec.module)
args = []
@ -382,7 +404,10 @@ async def instantiate_provider(
# TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api])
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
if (
not isinstance(provider_spec, AutoRoutedProviderSpec)
and provider_spec.api in additional_protocols
):
additional_api, _, _ = additional_protocols[provider_spec.api]
check_protocol_compliance(impl, additional_api)
@ -410,12 +435,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
obj_params = set(obj_sig.parameters)
obj_params.discard("self")
if not (proto_params <= obj_params):
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
logger.error(
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
)
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
method_owner = next(
(cls for cls in mro if name in cls.__dict__), None
)
if (
method_owner is None
or method_owner.__name__ == protocol.__name__
):
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:

View file

@ -13,6 +13,7 @@ LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
@ -69,22 +70,25 @@ while [[ $# -gt 0 ]]; do
;;
esac
done
PYTHON_BINARY="python"
case "$env_type" in
"venv")
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
else
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
source "$env_path_or_name/bin/activate"
source "$env_path_or_name/bin/activate"
fi
;;
"conda")
if ! is_command_available conda; then

View file

@ -18,15 +18,19 @@ def preserve_contexts_async_generator(
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
# Capture initial context values
initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
item = await gen.__anext__()
context_values = {context_var.name: context_var.get() for context_var in context_vars}
yield item
# Restore context values before any await
for context_var in context_vars:
_ = context_var.set(context_values[context_var.name])
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break