mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Fix precommit check after moving to ruff (#927)
Lint check in main branch is failing. This fixes the lint check after we moved to ruff in https://github.com/meta-llama/llama-stack/pull/921. We need to move to a `ruff.toml` file as well as fixing and ignoring some additional checks. Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
parent
4773092dd1
commit
34ab7a3b6c
217 changed files with 981 additions and 2681 deletions
|
@ -115,9 +115,7 @@ async def resolve_impls(
|
|||
- flatmaps, sorts and resolves the providers in dependency order
|
||||
- for each API, produces either a (local, passthrough or router) implementation
|
||||
"""
|
||||
routing_table_apis = set(
|
||||
x.routing_table_api for x in builtin_automatically_routed_apis()
|
||||
)
|
||||
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
|
||||
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
|
||||
|
||||
providers_with_specs = {}
|
||||
|
@ -125,16 +123,12 @@ async def resolve_impls(
|
|||
for api_str, providers in run_config.providers.items():
|
||||
api = Api(api_str)
|
||||
if api in routing_table_apis:
|
||||
raise ValueError(
|
||||
f"Provider for `{api_str}` is automatically provided and cannot be overridden"
|
||||
)
|
||||
raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden")
|
||||
|
||||
specs = {}
|
||||
for provider in providers:
|
||||
if provider.provider_type not in provider_registry[api]:
|
||||
raise ValueError(
|
||||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`")
|
||||
|
||||
p = provider_registry[api][provider.provider_type]
|
||||
if p.deprecation_error:
|
||||
|
@ -145,9 +139,7 @@ async def resolve_impls(
|
|||
log.warning(
|
||||
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
|
||||
)
|
||||
p.deps__ = [a.value for a in p.api_dependencies] + [
|
||||
a.value for a in p.optional_api_dependencies
|
||||
]
|
||||
p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=p,
|
||||
**(provider.model_dump()),
|
||||
|
@ -158,9 +150,7 @@ async def resolve_impls(
|
|||
providers_with_specs[key] = specs
|
||||
|
||||
apis_to_serve = run_config.apis or set(
|
||||
list(providers_with_specs.keys())
|
||||
+ [x.value for x in routing_table_apis]
|
||||
+ [x.value for x in router_apis]
|
||||
list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis]
|
||||
)
|
||||
|
||||
for info in builtin_automatically_routed_apis():
|
||||
|
@ -197,9 +187,7 @@ async def resolve_impls(
|
|||
)
|
||||
}
|
||||
|
||||
sorted_providers = topological_sort(
|
||||
{k: v.values() for k, v in providers_with_specs.items()}
|
||||
)
|
||||
sorted_providers = topological_sort({k: v.values() for k, v in providers_with_specs.items()})
|
||||
apis = [x[1].spec.api for x in sorted_providers]
|
||||
sorted_providers.append(
|
||||
(
|
||||
|
@ -237,9 +225,7 @@ async def resolve_impls(
|
|||
|
||||
inner_impls = {}
|
||||
if isinstance(provider.spec, RoutingTableProviderSpec):
|
||||
inner_impls = inner_impls_by_provider_id[
|
||||
f"inner-{provider.spec.router_api.value}"
|
||||
]
|
||||
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
|
||||
|
||||
impl = await instantiate_provider(
|
||||
provider,
|
||||
|
@ -336,10 +322,7 @@ async def instantiate_provider(
|
|||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
check_protocol_compliance(impl, protocols[provider_spec.api])
|
||||
if (
|
||||
not isinstance(provider_spec, AutoRoutedProviderSpec)
|
||||
and provider_spec.api in additional_protocols
|
||||
):
|
||||
if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols:
|
||||
additional_api, _, _ = additional_protocols[provider_spec.api]
|
||||
check_protocol_compliance(impl, additional_api)
|
||||
|
||||
|
@ -367,19 +350,12 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
obj_params = set(obj_sig.parameters)
|
||||
obj_params.discard("self")
|
||||
if not (proto_params <= obj_params):
|
||||
log.error(
|
||||
f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}"
|
||||
)
|
||||
log.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
|
||||
missing_methods.append((name, "signature_mismatch"))
|
||||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
method_owner = next(
|
||||
(cls for cls in mro if name in cls.__dict__), None
|
||||
)
|
||||
if (
|
||||
method_owner is None
|
||||
or method_owner.__name__ == protocol.__name__
|
||||
):
|
||||
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
|
||||
if method_owner is None or method_owner.__name__ == protocol.__name__:
|
||||
missing_methods.append((name, "not_actually_implemented"))
|
||||
|
||||
if missing_methods:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue