chore(lint): update Ruff ignores for project conventions and maintainability (#1184)

- Added new ignores from flake8-bugbear (`B007`, `B008`)
- Ignored `C901` (high function complexity) for now, pending review
- Maintained PyTorch conventions (`N812`, `N817`)
- Allowed `E731` (lambda assignments) for flexibility
- Consolidated existing ignores (`E402`, `E501`, `F405`, `C408`, `N812`)
- Documented rationale for each ignored rule

This keeps our linting aligned with project needs while tracking
potential fixes.

Signed-off-by: Sébastien Han <seb@redhat.com>

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-02-28 18:36:49 +01:00 committed by GitHub
parent 3b57d8ee88
commit 6fa257b475
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
33 changed files with 113 additions and 145 deletions

View file

@ -141,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
completer=WordCompleter(available_providers), completer=WordCompleter(available_providers),
complete_while_typing=True, complete_while_typing=True,
validator=Validator.from_callable( validator=Validator.from_callable(
lambda x: x in available_providers, lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
error_message="Invalid provider, use <TAB> to see options", error_message="Invalid provider, use <TAB> to see options",
), ),
) )

View file

@ -112,7 +112,7 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):
inference_providers = result.providers["inference"] inference_providers = result.providers["inference"]
assert len(inference_providers) == 2 assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == { assert {x.provider_id for x in inference_providers} == {
"remote::ollama-00", "remote::ollama-00",
"meta-reference-01", "meta-reference-01",
} }

View file

@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]: def stack_apis() -> List[Api]:
return [v for v in Api] return list(Api)
class AutoRoutedApiInfo(BaseModel): class AutoRoutedApiInfo(BaseModel):
@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]: def providable_apis() -> List[Api]:
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis()) routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect] return [api for api in Api if api not in routing_table_apis and api != Api.inspect]

View file

@ -115,8 +115,8 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order - flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation - for each API, produces either a (local, passthrough or router) implementation
""" """
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis()) routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = set(x.router_api for x in builtin_automatically_routed_apis()) router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {} providers_with_specs = {}

View file

@ -134,7 +134,7 @@ def rag_chat_page():
dict( dict(
name="builtin::rag/knowledge_search", name="builtin::rag/knowledge_search",
args={ args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs], "vector_db_ids": list(selected_vector_dbs),
}, },
) )
], ],

View file

@ -797,10 +797,10 @@ class ChatAgent(ShieldRunnerMixin):
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]: ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include # Determine which tools to include
agent_config_toolgroups = set( agent_config_toolgroups = {
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup) toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
for toolgroup in self.agent_config.toolgroups for toolgroup in self.agent_config.toolgroups
) }
toolgroups_for_turn_set = ( toolgroups_for_turn_set = (
agent_config_toolgroups agent_config_toolgroups
if toolgroups_for_turn is None if toolgroups_for_turn is None

View file

@ -86,7 +86,6 @@ class MetaReferenceEvalImpl(
) -> Job: ) -> Job:
task_def = self.benchmarks[benchmark_id] task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))

View file

@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
logprobs = [] logprobs = []
stop_reason = None stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request): for token_result in self.generator.completion(request):
tokens.append(token_result.token) tokens.append(token_result.token)
if token_result.text == "<|eot_id|>": if token_result.text == "<|eot_id|>":

View file

@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
return parse_message(maybe_json) return parse_message(maybe_json)
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
except ValueError as e: except ValueError:
return None return None
@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
if isinstance(obj, TaskResponse): if isinstance(obj, TaskResponse):
yield obj.result yield obj.result
except GeneratorExit as e: except GeneratorExit:
self.request_socket.send(encode_msg(CancelSentinel())) self.request_socket.send(encode_msg(CancelSentinel()))
while True: while True:
obj_json = self.request_socket.send() obj_json = self.request_socket.send()

View file

@ -7,6 +7,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# The file gets a special treatment for now?
# ruff: noqa: N803
import unittest import unittest
import torch import torch

View file

@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
) )
self.adapter_params = get_adapter_params(model) self.adapter_params = get_adapter_params(model)
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()]) self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
set_trainable_params(model, self.adapter_params) set_trainable_params(model, self.adapter_params)

View file

@ -133,7 +133,7 @@ class BraintrustScoringImpl(
async def shutdown(self) -> None: ... async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]: async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()] scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list: for f in scoring_fn_defs_list:
assert f.identifier.startswith("braintrust"), ( assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! " "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "

View file

@ -198,7 +198,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tool_prompt_format: if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring") warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
await check_health(self._config) # this raises errors await check_health(self._config) # this raises errors

View file

@ -106,7 +106,7 @@ async def convert_chat_completion_request(
payload.update(temperature=strategy.temperature) payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy): elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1: if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1") warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=strategy.top_k) nvext.update(top_k=strategy.top_k)
elif isinstance(strategy, GreedySamplingStrategy): elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1) nvext.update(top_k=-1)
@ -168,7 +168,7 @@ def convert_completion_request(
payload.update(top_p=request.sampling_params.top_p) payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k": elif request.sampling_params.strategy == "top_k":
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1: if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
warnings.warn("top_k must be -1 or >= 1") warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=request.sampling_params.top_k) nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy": elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1) nvext.update(top_k=-1)

View file

@ -39,12 +39,11 @@ class Testeval:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model): async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl, models_impl = ( eval_impl, benchmarks_impl, datasetio_impl, datasets_impl = (
eval_stack[Api.eval], eval_stack[Api.eval],
eval_stack[Api.benchmarks], eval_stack[Api.benchmarks],
eval_stack[Api.datasetio], eval_stack[Api.datasetio],
eval_stack[Api.datasets], eval_stack[Api.datasets],
eval_stack[Api.models],
) )
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval") await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@ -92,11 +91,10 @@ class Testeval:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_run_eval(self, eval_stack, inference_model, judge_model): async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = ( eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval], eval_stack[Api.eval],
eval_stack[Api.benchmarks], eval_stack[Api.benchmarks],
eval_stack[Api.datasets], eval_stack[Api.datasets],
eval_stack[Api.models],
) )
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval") await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@ -131,11 +129,10 @@ class Testeval:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_run_benchmark_eval(self, eval_stack, inference_model): async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = ( eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval], eval_stack[Api.eval],
eval_stack[Api.benchmarks], eval_stack[Api.benchmarks],
eval_stack[Api.datasets], eval_stack[Api.datasets],
eval_stack[Api.models],
) )
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()

View file

@ -18,54 +18,48 @@ from llama_stack.models.llama.sku_list import all_registered_models
INFERENCE_APIS = ["chat_completion"] INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"] FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"ollama": set( "ollama": {
[ CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value, CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value, CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value, CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value, CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_8b.value, CoreModelId.llama_guard_3_1b.value,
CoreModelId.llama_guard_3_1b.value, },
] "fireworks": {
), CoreModelId.llama3_1_8b_instruct.value,
"fireworks": set( CoreModelId.llama3_1_70b_instruct.value,
[ CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_1b_instruct.value, CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama3_2_90b_vision_instruct.value, CoreModelId.llama_guard_3_11b_vision.value,
CoreModelId.llama3_3_70b_instruct.value, },
CoreModelId.llama_guard_3_8b.value, "together": {
CoreModelId.llama_guard_3_11b_vision.value, CoreModelId.llama3_1_8b_instruct.value,
] CoreModelId.llama3_1_70b_instruct.value,
), CoreModelId.llama3_1_405b_instruct.value,
"together": set( CoreModelId.llama3_2_3b_instruct.value,
[ CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama_guard_3_11b_vision.value,
CoreModelId.llama3_2_11b_vision_instruct.value, },
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
),
} }

View file

@ -45,13 +45,11 @@ class TestScoring:
scoring_functions_impl, scoring_functions_impl,
datasetio_impl, datasetio_impl,
datasets_impl, datasets_impl,
models_impl,
) = ( ) = (
scoring_stack[Api.scoring], scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions], scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio], scoring_stack[Api.datasetio],
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models],
) )
scoring_fns_list = await scoring_functions_impl.list_scoring_functions() scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id provider_id = scoring_fns_list[0].provider_id
@ -102,13 +100,11 @@ class TestScoring:
scoring_functions_impl, scoring_functions_impl,
datasetio_impl, datasetio_impl,
datasets_impl, datasets_impl,
models_impl,
) = ( ) = (
scoring_stack[Api.scoring], scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions], scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio], scoring_stack[Api.datasetio],
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models],
) )
await register_dataset(datasets_impl, for_rag=True) await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
@ -163,13 +159,11 @@ class TestScoring:
scoring_functions_impl, scoring_functions_impl,
datasetio_impl, datasetio_impl,
datasets_impl, datasets_impl,
models_impl,
) = ( ) = (
scoring_stack[Api.scoring], scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions], scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio], scoring_stack[Api.datasetio],
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models],
) )
await register_dataset(datasets_impl, for_rag=True) await register_dataset(datasets_impl, for_rag=True)
rows = await datasetio_impl.get_rows_paginated( rows = await datasetio_impl.get_rows_paginated(

View file

@ -605,7 +605,7 @@ def convert_tool_call(
tool_name=tool_call.function.name, tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments), arguments=json.loads(tool_call.function.arguments),
) )
except Exception as e: except Exception:
return UnparseableToolCall( return UnparseableToolCall(
call_id=tool_call.id or "", call_id=tool_call.id or "",
tool_name=tool_call.function.name or "", tool_name=tool_call.function.name or "",
@ -876,7 +876,9 @@ async def convert_openai_chat_completion_stream(
# it is possible to have parallel tool calls in stream, but # it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream # ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1: if len(choice.delta.tool_calls) > 1:
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest") warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
)
if not enable_incremental_tool_calls: if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(

View file

@ -36,7 +36,7 @@ class RedisKVStoreImpl(KVStore):
value = await self.redis.get(key) value = await self.redis.get(key)
if value is None: if value is None:
return None return None
ttl = await self.redis.ttl(key) await self.redis.ttl(key)
return value return value
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:

View file

@ -32,7 +32,7 @@ def aggregate_categorical_count(
scoring_results: List[ScoringResultRow], scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results] scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores))) unique_scores = sorted(set(scores))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}} return {"categorical_count": {s: scores.count(s) for s in unique_scores}}

View file

@ -66,7 +66,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
return self.__class__.__name__ return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]: def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
return [x for x in self.supported_fn_defs_registry.values()] return list(self.supported_fn_defs_registry.values())
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None: def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry: if scoring_fn.identifier in self.supported_fn_defs_registry:

View file

@ -99,7 +99,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
template = template_func() template = template_func()
normal_deps, special_deps = get_provider_dependencies(template.providers) normal_deps, special_deps = get_provider_dependencies(template.providers)
# Combine all dependencies in order: normal deps, special deps, server deps # Combine all dependencies in order: normal deps, special deps, server deps
all_deps = sorted(list(set(normal_deps + SERVER_DEPENDENCIES))) + sorted(list(set(special_deps))) all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
return template.name, all_deps return template.name, all_deps
except Exception: except Exception:

View file

@ -123,39 +123,16 @@ select = [
"I", # isort "I", # isort
] ]
ignore = [ ignore = [
"E203", # The following ignores are desired by the project maintainers.
"E305", "E402", # Module level import not at top of file
"E402", "E501", # Line too long
"E501", # line too long "F405", # Maybe undefined or defined from star import
"E721", "C408", # Ignored because we like the dict keyword argument syntax
"E741", "N812", # Ignored because import torch.nn.functional as F is PyTorch convention
"F405",
"F841",
"C408", # ignored because we like the dict keyword argument syntax
"E302",
"W291",
"E303",
"N812", # ignored because import torch.nn.functional as F is PyTorch convention
"N817", # ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
"E731", # allow usage of assigning lambda expressions
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later. # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901", "C901", # Complexity of the function is too high
"C405",
"C414",
"N803",
"N999",
"C403",
"C416",
"B028",
"C419",
"C401",
"B023",
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
"EXE001",
"N802", # random naming hints don't need
# these ignores are from flake8-bugbear; please fix! # these ignores are from flake8-bugbear; please fix!
"B007",
"B008", "B008",
] ]

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -117,7 +117,7 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
assert len(providers) > 0, "No inference providers found" assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"] inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
model_ids = set(m.identifier for m in client.models.list()) model_ids = {m.identifier for m in client.models.list()}
model_ids.update(m.provider_resource_id for m in client.models.list()) model_ids.update(m.provider_resource_id for m in client.models.list())
if text_model_id and text_model_id not in model_ids: if text_model_id and text_model_id not in model_ids:

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -176,7 +176,7 @@ def test_embedding_truncation_error(
): ):
if inference_provider_type not in SUPPORTED_PROVIDERS: if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet") pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError) as excinfo: with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings( llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
) )
@ -243,7 +243,7 @@ def test_embedding_text_truncation_error(
): ):
if inference_provider_type not in SUPPORTED_PROVIDERS: if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet") pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError) as excinfo: with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings( llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
) )

View file

@ -139,7 +139,7 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
"top_k": 1, "top_k": 1,
}, },
) )
streamed_content = [chunk for chunk in response] streamed_content = list(response)
for chunk in streamed_content: for chunk in streamed_content:
if chunk.delta: # if there's a token, we expect logprobs if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty" assert chunk.logprobs, "Logprobs should not be empty"
@ -405,7 +405,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
assert delta.tool_call.tool_name == "get_object_namespace_list" assert delta.tool_call.tool_name == "get_object_namespace_list"
if delta.type == "tool_call" and delta.parse_status == "failed": if delta.type == "tool_call" and delta.parse_status == "failed":
# expect raw message that failed to parse in tool_call # expect raw message that failed to parse in tool_call
assert type(delta.tool_call) == str assert isinstance(delta.tool_call, str)
assert len(delta.tool_call) > 0 assert len(delta.tool_call) > 0
else: else:
for tc in response.completion_message.tool_calls: for tc in response.completion_message.tool_calls:

View file

@ -42,29 +42,27 @@ def featured_models():
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"ollama": set( "ollama": {
[ CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value, CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value, CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value, CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value, CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value, CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_8b.value, CoreModelId.llama_guard_3_1b.value,
CoreModelId.llama_guard_3_1b.value, },
] "tgi": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
), "vllm": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
"tgi": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
"vllm": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
} }

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -42,7 +42,7 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def model_providers(llama_stack_client): def model_providers(llama_stack_client):
return set([x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"]) return {x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"}
def test_unsafe_examples(llama_stack_client, llama_guard_text_shield_id): def test_unsafe_examples(llama_stack_client, llama_guard_text_shield_id):

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999