forked from phoenix-oss/llama-stack-mirror
# What does this PR do? - Update `/eval-tasks` to `/benchmarks` - ⚠️ Remove differentiation between `app` vs. `benchmark` eval task config. Now we only have `BenchmarkConfig`. The overloaded `benchmark` is confusing and does not add any value. Backward compatibility is preserved because the "type" field is not used anywhere. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan - This change is backward compatible - Run notebook tests with ``` pytest -v -s --nbval-lax ./docs/getting_started.ipynb pytest -v -s --nbval-lax ./docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb ``` <img width="846" alt="image" src="https://github.com/user-attachments/assets/d2fc06a7-593a-444f-bc1f-10ab9b0c843d" /> [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) --------- Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Sébastien Han <seb@redhat.com> Signed-off-by: reidliu <reid201711@gmail.com> Co-authored-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com> Co-authored-by: Ben Browning <ben324@gmail.com> Co-authored-by: Sébastien Han <seb@redhat.com> Co-authored-by: Reid <61492567+reidliu41@users.noreply.github.com> Co-authored-by: reidliu <reid201711@gmail.com> Co-authored-by: Yuan Tang <terrytangyuan@gmail.com>
390 lines
14 KiB
Python
390 lines
14 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
from typing import Any, Dict, Iterable, List, Set, Tuple
|
|
|
|
from llama_stack.apis.agents import Agents
|
|
from llama_stack.apis.benchmarks import Benchmarks
|
|
from llama_stack.apis.datasetio import DatasetIO
|
|
from llama_stack.apis.datasets import Datasets
|
|
from llama_stack.apis.eval import Eval
|
|
from llama_stack.apis.inference import Inference
|
|
from llama_stack.apis.inspect import Inspect
|
|
from llama_stack.apis.models import Models
|
|
from llama_stack.apis.post_training import PostTraining
|
|
from llama_stack.apis.safety import Safety
|
|
from llama_stack.apis.scoring import Scoring
|
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
|
from llama_stack.apis.shields import Shields
|
|
from llama_stack.apis.telemetry import Telemetry
|
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
|
from llama_stack.apis.vector_dbs import VectorDBs
|
|
from llama_stack.apis.vector_io import VectorIO
|
|
from llama_stack.distribution.client import get_client_impl
|
|
from llama_stack.distribution.datatypes import (
|
|
AutoRoutedProviderSpec,
|
|
Provider,
|
|
RoutingTableProviderSpec,
|
|
StackRunConfig,
|
|
)
|
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
|
from llama_stack.distribution.store import DistributionRegistry
|
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
|
from llama_stack.providers.datatypes import (
|
|
Api,
|
|
BenchmarksProtocolPrivate,
|
|
DatasetsProtocolPrivate,
|
|
InlineProviderSpec,
|
|
ModelsProtocolPrivate,
|
|
ProviderSpec,
|
|
RemoteProviderConfig,
|
|
RemoteProviderSpec,
|
|
ScoringFunctionsProtocolPrivate,
|
|
ShieldsProtocolPrivate,
|
|
ToolsProtocolPrivate,
|
|
VectorDBsProtocolPrivate,
|
|
)
|
|
|
|
# Module-level logger: resolution progress, deprecation notices and
# protocol-compliance errors are reported through it.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class InvalidProviderError(Exception):
    """Raised when a configured provider must not be used (e.g. its spec carries a deprecation error)."""
|
|
|
|
|
|
def api_protocol_map() -> Dict[Api, Any]:
    """Map each API to the protocol class that defines its public interface.

    A new dict is built on every call, so callers may modify the result freely.
    """
    api_protocol_pairs = (
        (Api.agents, Agents),
        (Api.inference, Inference),
        (Api.inspect, Inspect),
        (Api.vector_io, VectorIO),
        (Api.vector_dbs, VectorDBs),
        (Api.models, Models),
        (Api.safety, Safety),
        (Api.shields, Shields),
        (Api.telemetry, Telemetry),
        (Api.datasetio, DatasetIO),
        (Api.datasets, Datasets),
        (Api.scoring, Scoring),
        (Api.scoring_functions, ScoringFunctions),
        (Api.eval, Eval),
        (Api.benchmarks, Benchmarks),
        (Api.post_training, PostTraining),
        (Api.tool_groups, ToolGroups),
        (Api.tool_runtime, ToolRuntime),
    )
    return dict(api_protocol_pairs)
|
|
|
|
|
|
def additional_protocols_map() -> Dict[Api, Any]:
    """Describe the extra, registry-style surface that some provider APIs expose.

    Each value is a ``(private_protocol, public_protocol, public_api)`` triple:
    - ``private_protocol`` is what local provider implementations of the key API
      are additionally checked against (see instantiate_provider);
    - ``public_protocol`` / ``public_api`` are used to build an extra remote client
      stored under ``public_api`` (see resolve_remote_stack_impls).
    """
    # NOTE(review): Api.tool_groups maps onto itself here, unlike the other
    # entries whose key is the provider API -- confirm this is intended.
    entries = [
        (Api.inference, (ModelsProtocolPrivate, Models, Api.models)),
        (Api.tool_groups, (ToolsProtocolPrivate, ToolGroups, Api.tool_groups)),
        (Api.vector_io, (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs)),
        (Api.safety, (ShieldsProtocolPrivate, Shields, Api.shields)),
        (Api.datasetio, (DatasetsProtocolPrivate, Datasets, Api.datasets)),
        (
            Api.scoring,
            (ScoringFunctionsProtocolPrivate, ScoringFunctions, Api.scoring_functions),
        ),
        (Api.eval, (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks)),
    ]
    return dict(entries)
|
|
|
|
|
|
# TODO: make this naming far less confusing: Provider vs. ProviderSpec vs. ProviderWithSpec.
# A `Provider` entry (id, type, config) paired with the ProviderSpec that says how
# to instantiate it. resolve_impls builds these for configured providers as well as
# for the synthesized routing-table, auto-router and inspect providers.
class ProviderWithSpec(Provider):
    # Registry spec for configured providers, or a synthesized
    # RoutingTableProviderSpec / AutoRoutedProviderSpec / InlineProviderSpec.
    spec: ProviderSpec
|
|
|
|
|
|
# Available provider specs for each API, keyed by `provider_type`
# (resolve_impls looks providers up as provider_registry[api][provider.provider_type]).
ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
|
|
|
|
|
|
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
async def resolve_impls(
    run_config: StackRunConfig,
    provider_registry: ProviderRegistry,
    dist_registry: DistributionRegistry,
) -> Dict[Api, Any]:
    """
    Does two things:
    - flatmaps, sorts and resolves the providers in dependency order
    - for each API, produces either a (local, passthrough or router) implementation

    Args:
        run_config: Stack run configuration. ``providers`` maps API names to the
            configured providers; the optional ``apis`` decides which auto-routed
            APIs get a router and routing table.
        provider_registry: Available provider specs per API, keyed by ``provider_type``.
        dist_registry: Registry handed to routing-table implementations.

    Returns:
        Mapping from each served API to its implementation. Providers configured for
        an auto-routed API are not returned directly; they are passed to that API's
        routing table as "inner" implementations keyed by provider_id.

    Raises:
        ValueError: If the run config overrides an automatically provided
            routing-table API, or names a ``provider_type`` the registry lacks.
        InvalidProviderError: If a configured provider's spec has a
            ``deprecation_error``.
    """
    # Materialize once; the result is consulted three times below.
    routed_api_infos = list(builtin_automatically_routed_apis())
    routing_table_apis = {x.routing_table_api for x in routed_api_infos}
    router_apis = {x.router_api for x in routed_api_infos}

    providers_with_specs = {}

    for api_str, providers in run_config.providers.items():
        api = Api(api_str)
        if api in routing_table_apis:
            raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")

        specs = {}
        for provider in providers:
            if provider.provider_type not in provider_registry[api]:
                raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")

            p = provider_registry[api][provider.provider_type]
            if p.deprecation_error:
                # Logger.error only takes printf-style args; the former termcolor-style
                # ("red", attrs=[...]) arguments raised TypeError and hid this error.
                log.error(p.deprecation_error)
                raise InvalidProviderError(p.deprecation_error)

            elif p.deprecation_warning:
                log.warning(
                    f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
                )
            # Record dependency names on the (shared) registry spec; topological_sort reads deps__.
            p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
            spec = ProviderWithSpec(
                spec=p,
                **(provider.model_dump()),
            )
            specs[provider.provider_id] = spec

        # Providers of an auto-routed API sit behind its routing table.
        key = api_str if api not in router_apis else f"inner-{api_str}"
        providers_with_specs[key] = specs

    apis_to_serve = run_config.apis or set(
        list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
    )

    # Each served auto-routed API gets a synthesized routing table (wrapping the
    # "inner-" providers) and an auto-router that depends on that table.
    for info in routed_api_infos:
        if info.router_api.value not in apis_to_serve:
            continue

        providers_with_specs[info.routing_table_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__routing_table__",
                provider_type="__routing_table__",
                config={},
                spec=RoutingTableProviderSpec(
                    api=info.routing_table_api,
                    router_api=info.router_api,
                    module="llama_stack.distribution.routers",
                    api_dependencies=[],
                    deps__=([f"inner-{info.router_api.value}"]),
                ),
            )
        }

        providers_with_specs[info.router_api.value] = {
            "__builtin__": ProviderWithSpec(
                provider_id="__autorouted__",
                provider_type="__autorouted__",
                config={},
                spec=AutoRoutedProviderSpec(
                    api=info.router_api,
                    module="llama_stack.distribution.routers",
                    routing_table_api=info.routing_table_api,
                    api_dependencies=[info.routing_table_api],
                    deps__=([info.routing_table_api.value]),
                ),
            )
        }

    sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
    apis = [x[1].spec.api for x in sorted_providers]
    # The built-in inspect provider depends on every other API, so it is appended last.
    sorted_providers.append(
        (
            "inspect",
            ProviderWithSpec(
                provider_id="__builtin__",
                provider_type="__builtin__",
                config={
                    # model_dump() is the pydantic v2 spelling of the deprecated .dict().
                    "run_config": run_config.model_dump(),
                },
                spec=InlineProviderSpec(
                    api=Api.inspect,
                    provider_type="__builtin__",
                    config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
                    module="llama_stack.distribution.inspect",
                    api_dependencies=apis,
                    deps__=([x.value for x in apis]),
                ),
            ),
        )
    )

    log.info(f"Resolved {len(sorted_providers)} providers")
    for api_str, provider in sorted_providers:
        log.info(f" {api_str} => {provider.provider_id}")
    log.info("")

    impls = {}
    inner_impls_by_provider_id = {f"inner-{x.value}": {} for x in router_apis}
    for api_str, provider in sorted_providers:
        # Dependencies were instantiated earlier thanks to the topological order.
        deps = {a: impls[a] for a in provider.spec.api_dependencies}
        for a in provider.spec.optional_api_dependencies:
            if a in impls:
                deps[a] = impls[a]

        inner_impls = {}
        if isinstance(provider.spec, RoutingTableProviderSpec):
            inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]

        impl = await instantiate_provider(
            provider,
            deps,
            inner_impls,
            dist_registry,
        )
        # TODO: ugh slightly redesign this shady looking code
        if "inner-" in api_str:
            inner_impls_by_provider_id[api_str][provider.provider_id] = impl
        else:
            api = Api(api_str)
            impls[api] = impl

    return impls
|
|
|
|
|
|
def topological_sort(
    providers_with_specs: Dict[str, Iterable["ProviderWithSpec"]],
) -> List[Tuple[str, "ProviderWithSpec"]]:
    """Order providers so that every API's dependencies come before it.

    Args:
        providers_with_specs: Maps an API key (e.g. ``"inference"`` or
            ``"inner-inference"``) to its providers. Each provider's
            ``spec.deps__`` lists the API keys it depends on. Every value is
            iterated twice, so it must be re-iterable (a list or a dict view).

    Returns:
        ``(api_key, provider)`` pairs in dependency order. Providers that share a
        key keep their original relative order.

    Notes:
        Dependencies not present in ``providers_with_specs`` are ignored. Cycles
        are not reported; a key is marked visited before its dependencies are
        explored, so a cycle is simply broken at that point.
    """

    def visit(api_key: str, visited: Set[str], order: List[str]) -> None:
        # Depth-first post-order: emit a key only after all its dependencies.
        visited.add(api_key)
        for provider in providers_with_specs[api_key]:
            for dep in provider.spec.deps__:
                if dep not in visited and dep in providers_with_specs:
                    visit(dep, visited, order)
        order.append(api_key)

    visited: Set[str] = set()
    order: List[str] = []
    for api_key in providers_with_specs:
        if api_key not in visited:
            visit(api_key, visited, order)

    return [(api_key, provider) for api_key in order for provider in providers_with_specs[api_key]]
|
|
|
|
|
|
# Returns an instance implementing the protocol corresponding to the Api.
async def instantiate_provider(
    provider: ProviderWithSpec,
    deps: Dict[Api, Any],
    inner_impls: Dict[str, Any],
    dist_registry: DistributionRegistry,
):
    """Import the provider's module, call its factory function and validate the result.

    The factory is chosen by spec type:
    - RemoteProviderSpec: ``get_adapter_impl(config, deps)``
    - AutoRoutedProviderSpec: ``get_auto_router_impl(api, routing_table_impl, deps)``
    - RoutingTableProviderSpec: ``get_routing_table_impl(api, inner_impls, deps, dist_registry)``
    - any other spec (inline providers): ``get_provider_impl(config, deps)``

    Args:
        provider: The provider entry together with its spec.
        deps: Already-instantiated implementations of the APIs this provider
            depends on, keyed by Api.
        inner_impls: Inner provider implementations keyed by provider_id; only
            used for routing tables.
        dist_registry: Distribution registry; only used for routing tables.

    Returns:
        The implementation, tagged with ``__provider_id__``,
        ``__provider_spec__`` and ``__provider_config__`` (None for routers and
        routing tables).

    Raises:
        ValueError: If the implementation does not satisfy the API's protocol
            (raised by check_protocol_compliance).
    """
    protocols = api_protocol_map()
    additional_protocols = additional_protocols_map()

    provider_spec = provider.spec
    module = importlib.import_module(provider_spec.module)

    if isinstance(provider_spec, RemoteProviderSpec):
        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)

        method = "get_adapter_impl"
        args = [config, deps]

    elif isinstance(provider_spec, AutoRoutedProviderSpec):
        method = "get_auto_router_impl"

        config = None
        args = [provider_spec.api, deps[provider_spec.routing_table_api], deps]
    elif isinstance(provider_spec, RoutingTableProviderSpec):
        method = "get_routing_table_impl"

        config = None
        args = [provider_spec.api, inner_impls, deps, dist_registry]
    else:
        method = "get_provider_impl"

        config_type = instantiate_class_type(provider_spec.config_class)
        config = config_type(**provider.config)
        args = [config, deps]

    fn = getattr(module, method)
    impl = await fn(*args)
    impl.__provider_id__ = provider.provider_id
    impl.__provider_spec__ = provider_spec
    impl.__provider_config__ = config

    # TODO: check compliance for special tool groups
    # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
    check_protocol_compliance(impl, protocols[provider_spec.api])
    # Providers of an API listed in additional_protocols_map must also satisfy the
    # private protocol (first element of the triple); auto-routers are exempt.
    if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
        additional_api, _, _ = additional_protocols[provider_spec.api]
        check_protocol_compliance(impl, additional_api)

    return impl
|
|
|
|
|
|
def check_protocol_compliance(obj: Any, protocol: Any) -> None:
    """Verify that ``obj`` implements every web method declared on ``protocol``.

    Only functions on ``protocol`` carrying a ``__webmethod__`` attribute are
    checked. For each one, ``obj`` must:
    - have an attribute of that name ("missing" otherwise),
    - whose value is callable ("not_callable" otherwise),
    - accepting at least the protocol's parameters, ignoring ``self``
      ("signature_mismatch" otherwise),
    - defined by some class in ``obj``'s MRO other than ``protocol`` itself
      ("not_actually_implemented" otherwise).

    Raises:
        ValueError: Listing every ``(method_name, reason)`` pair that failed.
    """
    missing_methods = []

    mro = type(obj).__mro__
    for name, value in inspect.getmembers(protocol):
        if not (inspect.isfunction(value) and hasattr(value, "__webmethod__")):
            continue
        if not hasattr(obj, name):
            missing_methods.append((name, "missing"))
            continue
        obj_method = getattr(obj, name)
        if not callable(obj_method):
            missing_methods.append((name, "not_callable"))
            continue

        # The implementation may accept extra parameters, but never fewer.
        proto_params = set(inspect.signature(value).parameters) - {"self"}
        obj_params = set(inspect.signature(obj_method).parameters) - {"self"}
        if not (proto_params <= obj_params):
            log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
            missing_methods.append((name, "signature_mismatch"))
            continue

        # A method inherited straight from the protocol is only a stub. Compare the
        # owning class by identity: an unrelated class that merely shares the
        # protocol's __name__ is a genuine implementation.
        method_owner = next((cls for cls in mro if name in cls.__dict__), None)
        if method_owner is None or method_owner is protocol:
            missing_methods.append((name, "not_actually_implemented"))

    if missing_methods:
        # These tags are set by instantiate_provider; tolerate their absence so the
        # intended ValueError is not replaced by an AttributeError.
        provider_id = getattr(obj, "__provider_id__", "<unknown>")
        provider_api = getattr(getattr(obj, "__provider_spec__", None), "api", "<unknown>")
        raise ValueError(
            f"Provider `{provider_id} ({provider_api})` does not implement the following methods:\n{missing_methods}"
        )
|
|
|
|
|
|
async def resolve_remote_stack_impls(
    config: RemoteProviderConfig,
    apis: List[str],
) -> Dict[Api, Any]:
    """Build client implementations for APIs served by a remote stack.

    For each API name in ``apis`` a client is obtained via ``get_client_impl`` for
    that API's protocol. When the API also appears in additional_protocols_map, a
    second client for its public companion protocol is stored under the companion
    API key.
    """
    protocol_by_api = api_protocol_map()
    companion_by_api = additional_protocols_map()

    resolved: Dict[Api, Any] = {}
    for name in apis:
        target = Api(name)
        resolved[target] = await get_client_impl(protocol_by_api[target], config, {})

        companion = companion_by_api.get(target)
        if companion is None:
            continue
        _, companion_protocol, companion_api = companion
        resolved[companion_api] = await get_client_impl(companion_protocol, config, {})

    return resolved
|