forked from phoenix-oss/llama-stack-mirror
Merge branch 'main' into eval_api_final
This commit is contained in:
commit
bc0cd07008
79 changed files with 3257 additions and 2358 deletions
|
@ -12,6 +12,7 @@ from llama_stack.apis.benchmarks import Benchmarks
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.evaluation import Evaluation
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
|
@ -74,6 +75,7 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
Api.tool_groups: ToolGroups,
|
||||
Api.tool_runtime: ToolRuntime,
|
||||
Api.evaluation: Evaluation,
|
||||
Api.files: Files,
|
||||
}
|
||||
|
||||
|
||||
|
@ -107,7 +109,9 @@ async def resolve_impls(
|
|||
2. Sorting them in dependency order.
|
||||
3. Instantiating them with required dependencies.
|
||||
"""
|
||||
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
|
||||
routing_table_apis = {
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
}
|
||||
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
|
||||
|
||||
providers_with_specs = validate_and_prepare_providers(
|
||||
|
@ -115,7 +119,9 @@ async def resolve_impls(
|
|||
)
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve))
|
||||
|
@ -180,17 +186,23 @@ def validate_and_prepare_providers(
|
|||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
|
||||
raise ValueError(
|
||||
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
||||
)
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if not provider.provider_id or provider.provider_id == "__disabled__":
|
||||
logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled")
|
||||
logger.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is disabled"
|
||||
)
|
||||
continue
|
||||
|
||||
validate_provider(provider, api, provider_registry)
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||
p.deps__ = [a.value for a in p.api_dependencies] + [
|
||||
a.value for a in p.optional_api_dependencies
|
||||
]
|
||||
spec = ProviderWithSpec(spec=p, **provider.model_dump())
|
||||
specs[provider.provider_id] = spec
|
||||
|
||||
|
@ -200,10 +212,14 @@ def validate_and_prepare_providers(
|
|||
return providers_with_specs
|
||||
|
||||
|
||||
def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry):
|
||||
def validate_provider(
|
||||
provider: Provider, api: Api, provider_registry: ProviderRegistry
|
||||
):
|
||||
"""Validates if the provider is allowed and handles deprecations."""
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||
raise ValueError(
|
||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
|
@ -278,7 +294,9 @@ async def instantiate_providers(
|
|||
) -> Dict:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: Dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {
|
||||
f"inner-{x.value}": {} for x in router_apis
|
||||
}
|
||||
for api_str, provider in sorted_providers:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
|
@ -287,7 +305,9 @@ async def instantiate_providers(
|
|||
|
||||
inner_impls = {}
|
||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
inner_impls = inner_impls_by_provider_id[
|
||||
f"inner-{provider.spec.router_api.value}"
|
||||
]
|
||||
|
||||
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry)
|
||||
|
||||
|
@ -345,7 +365,9 @@ async def instantiate_provider(
|
|||
|
||||
provider_spec = provider.spec
|
||||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
raise AttributeError(
|
||||
f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute"
|
||||
)
|
||||
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
args = []
|
||||
|
@ -382,7 +404,10 @@ async def instantiate_provider(
|
|||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
|
||||
if (
|
||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||
and provider_spec.api in additional_protocols
|
||||
):
|
||||
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||
check_protocol_compliance(impl, additional_api)
|
||||
|
||||
|
@ -410,12 +435,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
logger.error(
|
||||
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||
)
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
|
||||
if method_owner is None or method_owner.__name__ == protocol.__name__:
|
||||
method_owner = next(
|
||||
(cls for cls in mro if name in cls.__dict__), None
|
||||
)
|
||||
if (
|
||||
method_owner is None
|
||||
or method_owner.__name__ == protocol.__name__
|
||||
):
|
||||
missing_methods.append((name, "not_actually_implemented"))
|
||||
|
||||
if missing_methods:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue