chore(lint): update Ruff ignores for project conventions and maintainability (#1184)

- Added new ignores from flake8-bugbear (`B007`, `B008`)
- Ignored `C901` (high function complexity) for now, pending review
- Maintained PyTorch conventions (`N812`, `N817`)
- Allowed `E731` (lambda assignments) for flexibility
- Consolidated existing ignores (`E402`, `E501`, `F405`, `C408`, `N812`)
- Documented rationale for each ignored rule

This keeps our linting aligned with project needs while tracking
potential fixes.
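
For context, a minimal sketch of the two flake8-bugbear patterns now ignored (hypothetical snippets, not code from this repo):

```python
def get_default_timeout() -> float:
    return 30.0


# B008: function call in a default argument; it runs once at definition time,
# so every call shares the same computed value.
def fetch(url: str, timeout: float = get_default_timeout()) -> str:
    return f"GET {url} (timeout={timeout}s)"


# B007: the loop control variable `key` is never used in the loop body;
# ruff would suggest renaming it to `_key` or iterating over .values().
def total_count(counts: dict[str, int]) -> int:
    total = 0
    for key, value in counts.items():
        total += value
    return total
```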

Signed-off-by: Sébastien Han <seb@redhat.com>

Authored by Sébastien Han on 2025-02-28 18:36:49 +01:00, committed by GitHub
commit 6fa257b475 (parent 3b57d8ee88)
33 changed files with 113 additions and 145 deletions

@@ -141,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in available_providers,
lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
error_message="Invalid provider, use <TAB> to see options",
),
)
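
The `# noqa: B023` above works around an apparent ruff false positive (see the linked issue). For context, a hedged sketch of the late-binding behaviour B023 normally guards against:

```python
# Hypothetical illustration of B023 (function definition does not bind loop variable).
callbacks = []
for provider in ["ollama", "vllm", "tgi"]:
    # Each lambda captures the variable `provider`, not its current value, so by the
    # time these run they all see the last value of the loop variable.
    callbacks.append(lambda: provider)

print([cb() for cb in callbacks])  # ['tgi', 'tgi', 'tgi']

# The usual fix is to bind the current value explicitly via a default argument:
callbacks = [lambda p=provider: p for provider in ["ollama", "vllm", "tgi"]]
print([cb() for cb in callbacks])  # ['ollama', 'vllm', 'tgi']
```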

@@ -112,7 +112,7 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):
inference_providers = result.providers["inference"]
assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == {
assert {x.provider_id for x in inference_providers} == {
"remote::ollama-00",
"meta-reference-01",
}

@@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]:
return [v for v in Api]
return list(Api)
class AutoRoutedApiInfo(BaseModel):
@@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
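
Several hunks in this commit apply the same mechanical rewrite. A small hypothetical example of the C401/C416 patterns being cleaned up:

```python
providers = [{"id": "ollama"}, {"id": "vllm"}, {"id": "ollama"}]

ids = set(p["id"] for p in providers)  # C401: generator wrapped in set()
ids = {p["id"] for p in providers}     # preferred: set comprehension

copies = [p for p in providers]        # C416: comprehension that only copies
copies = list(providers)               # preferred: list() says the same thing directly
```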

@@ -115,8 +115,8 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {}

@@ -134,7 +134,7 @@ def rag_chat_page():
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
"vector_db_ids": list(selected_vector_dbs),
},
)
],

@@ -797,10 +797,10 @@ class ChatAgent(ShieldRunnerMixin):
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
agent_config_toolgroups = {
toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
for toolgroup in self.agent_config.toolgroups
)
}
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None

@@ -86,7 +86,6 @@ class MetaReferenceEvalImpl(
) -> Job:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))

@@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
logprobs = []
stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":

@@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
return parse_message(maybe_json)
except json.JSONDecodeError:
return None
except ValueError as e:
except ValueError:
return None
@@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
if isinstance(obj, TaskResponse):
yield obj.result
except GeneratorExit as e:
except GeneratorExit:
self.request_socket.send(encode_msg(CancelSentinel()))
while True:
obj_json = self.request_socket.send()
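
Both changes in this hunk drop exception names that were never used (F841). A minimal, hypothetical sketch of the preferred form:

```python
import json


def parse_or_none(raw: str) -> dict | None:
    try:
        return json.loads(raw)
    except json.JSONDecodeError:  # no `as e`: the exception object is never used (F841)
        return None
```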

@@ -7,6 +7,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# The file gets a special treatment for now?
# ruff: noqa: N803
import unittest
import torch

@@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
)
self.adapter_params = get_adapter_params(model)
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
set_trainable_params(model, self.adapter_params)
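
This is the C419 cleanup: `any()` accepts a generator directly, so building an intermediate list is unnecessary and loses short-circuiting. A small hypothetical illustration:

```python
adapter_params = {"layer1.magnitude": 0.1, "layer1.weight": 0.2}

is_dora = any(["magnitude" in k for k in adapter_params])  # C419: builds the whole list first
is_dora = any("magnitude" in k for k in adapter_params)    # generator form short-circuits
print(is_dora)  # True
```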

@@ -133,7 +133,7 @@ class BraintrustScoringImpl(
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list:
assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "

@@ -198,7 +198,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
await check_health(self._config) # this raises errors
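
The `stacklevel=2` additions here and below address B028 (`warnings.warn` called without an explicit stacklevel). A hedged sketch of why it matters:

```python
import warnings


def chat_completion(tool_prompt_format: str | None = None) -> None:
    if tool_prompt_format:
        # With the default stacklevel=1 the warning points at this line inside the
        # adapter; stacklevel=2 makes it point at the caller, which is the code the
        # user actually needs to change.
        warnings.warn("tool_prompt_format is not supported, ignoring", stacklevel=2)


chat_completion(tool_prompt_format="json")
```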

@@ -106,7 +106,7 @@ async def convert_chat_completion_request(
payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1")
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=strategy.top_k)
elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1)
@@ -168,7 +168,7 @@ def convert_completion_request(
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
warnings.warn("top_k must be -1 or >= 1")
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)

@@ -39,12 +39,11 @@ class Testeval:
@pytest.mark.asyncio
async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl, models_impl = (
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.benchmarks],
eval_stack[Api.datasetio],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@@ -92,11 +91,10 @@ class Testeval:
@pytest.mark.asyncio
async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = (
eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.benchmarks],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@@ -131,11 +129,10 @@ class Testeval:
@pytest.mark.asyncio
async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = (
eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.benchmarks],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
response = await datasets_impl.list_datasets()

@@ -18,54 +18,48 @@ from llama_stack.models.llama.sku_list import all_registered_models
INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
SUPPORTED_MODELS = {
"ollama": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
]
),
"fireworks": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
),
"together": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
),
"ollama": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
},
"fireworks": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
"together": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
}
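
The `set([...])` to `{...}` rewrite in this hunk is the C405 fix (a list literal passed to `set()`). A tiny hypothetical illustration:

```python
supported = set(["llama3-8b", "llama3-70b"])  # C405: builds a list, then a set
supported = {"llama3-8b", "llama3-70b"}       # set literal, one step
```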

@@ -45,13 +45,11 @@ class TestScoring:
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
@@ -102,13 +100,11 @@ class TestScoring:
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets()
@@ -163,13 +159,11 @@ class TestScoring:
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl, for_rag=True)
rows = await datasetio_impl.get_rows_paginated(

@@ -605,7 +605,7 @@ def convert_tool_call(
tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments),
)
except Exception as e:
except Exception:
return UnparseableToolCall(
call_id=tool_call.id or "",
tool_name=tool_call.function.name or "",
@@ -876,7 +876,9 @@ async def convert_openai_chat_completion_stream(
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
)
if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk(

@@ -36,7 +36,7 @@ class RedisKVStoreImpl(KVStore):
value = await self.redis.get(key)
if value is None:
return None
ttl = await self.redis.ttl(key)
await self.redis.ttl(key)
return value
async def delete(self, key: str) -> None:

@@ -32,7 +32,7 @@ def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores)))
unique_scores = sorted(set(scores))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}

@@ -66,7 +66,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
return [x for x in self.supported_fn_defs_registry.values()]
return list(self.supported_fn_defs_registry.values())
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry:

@@ -99,7 +99,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
template = template_func()
normal_deps, special_deps = get_provider_dependencies(template.providers)
# Combine all dependencies in order: normal deps, special deps, server deps
all_deps = sorted(list(set(normal_deps + SERVER_DEPENDENCIES))) + sorted(list(set(special_deps)))
all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
return template.name, all_deps
except Exception:
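
This is the C414 cleanup: `sorted()` already returns a list, so the inner `list()` call is redundant. Illustrative example with hypothetical values:

```python
deps = {"uvicorn", "fastapi", "fastapi"}

ordered = sorted(list(deps))  # C414: the list() wrapper adds nothing
ordered = sorted(deps)        # same result, one less copy
print(ordered)                # ['fastapi', 'uvicorn']
```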

@@ -123,39 +123,16 @@ select = [
"I", # isort
]
ignore = [
"E203",
"E305",
"E402",
"E501", # line too long
"E721",
"E741",
"F405",
"F841",
"C408", # ignored because we like the dict keyword argument syntax
"E302",
"W291",
"E303",
"N812", # ignored because import torch.nn.functional as F is PyTorch convention
"N817", # ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
"E731", # allow usage of assigning lambda expressions
# The following ignores are desired by the project maintainers.
"E402", # Module level import not at top of file
"E501", # Line too long
"F405", # Maybe undefined or defined from star import
"C408", # Ignored because we like the dict keyword argument syntax
"N812", # Ignored because import torch.nn.functional as F is PyTorch convention
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901",
"C405",
"C414",
"N803",
"N999",
"C403",
"C416",
"B028",
"C419",
"C401",
"B023",
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
"EXE001",
"N802", # random naming hints don't need
"C901", # Complexity of the function is too high
# these ignores are from flake8-bugbear; please fix!
"B007",
"B008",
]
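
The global list above is complemented by file-level suppressions in the hunks that follow. A hedged sketch of how a module-scoped directive differs from a per-line `# noqa` (hypothetical file contents):

```python
# ruff: noqa: N999
# The directive above is file-scoped: it silences N999 (invalid module name) for the
# whole module, which is what the __init__.py hunks below rely on.


def check(X):  # noqa: N803  (per-line form: silences one rule on this line only)
    return X is not None
```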

@@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

@@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

@@ -117,7 +117,7 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
model_ids = set(m.identifier for m in client.models.list())
model_ids = {m.identifier for m in client.models.list()}
model_ids.update(m.provider_resource_id for m in client.models.list())
if text_model_id and text_model_id not in model_ids:

@@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

@@ -176,7 +176,7 @@ def test_embedding_truncation_error(
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError) as excinfo:
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
)
@@ -243,7 +243,7 @@ def test_embedding_text_truncation_error(
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError) as excinfo:
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
)

@@ -139,7 +139,7 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
"top_k": 1,
},
)
streamed_content = [chunk for chunk in response]
streamed_content = list(response)
for chunk in streamed_content:
if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
@@ -405,7 +405,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
assert delta.tool_call.tool_name == "get_object_namespace_list"
if delta.type == "tool_call" and delta.parse_status == "failed":
# expect raw message that failed to parse in tool_call
assert type(delta.tool_call) == str
assert isinstance(delta.tool_call, str)
assert len(delta.tool_call) > 0
else:
for tc in response.completion_message.tool_calls:

@@ -42,29 +42,27 @@ def featured_models():
SUPPORTED_MODELS = {
"ollama": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
]
),
"tgi": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
"vllm": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
"ollama": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
},
"tgi": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
"vllm": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
}

@@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

@@ -42,7 +42,7 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session")
def model_providers(llama_stack_client):
return set([x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"])
return {x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"}
def test_unsafe_examples(llama_stack_client, llama_guard_text_shield_id):

@@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999