Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
chore(lint): update Ruff ignores for project conventions and maintainability (#1184)
- Added new ignores from flake8-bugbear (`B007`, `B008`)
- Ignored `C901` (high function complexity) for now, pending review
- Maintained PyTorch conventions (`N812`, `N817`)
- Allowed `E731` (lambda assignments) for flexibility
- Consolidated existing ignores (`E402`, `E501`, `F405`, `C408`, `N812`)
- Documented rationale for each ignored rule

This keeps our linting aligned with project needs while tracking potential fixes.

Signed-off-by: Sébastien Han <seb@redhat.com>
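For reference, the two newly ignored flake8-bugbear rules flag patterns like the following (an illustrative sketch, not code from this repository):

    import time

    # B007: the loop control variable is assigned but never used in the body.
    for idx, line in enumerate(["a", "b"]):  # B007 fires on `idx`
        print(line)

    # B008: a function call in a default argument runs once, at definition time,
    # so every call silently reuses the import-time value.
    def log_event(message: str, timestamp: float = time.time()):  # B008 fires here
        return (timestamp, message)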
parent 3b57d8ee88
commit 6fa257b475

33 changed files with 113 additions and 145 deletions
@@ -141,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                 completer=WordCompleter(available_providers),
                 complete_while_typing=True,
                 validator=Validator.from_callable(
-                    lambda x: x in available_providers,
+                    lambda x: x in available_providers,  # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
                     error_message="Invalid provider, use <TAB> to see options",
                 ),
             )

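The `noqa` above suppresses B023 (function definition does not bind loop variable); the linked ruff issue tracks false positives for cases like this one. A minimal sketch, not from this commit, of the real bug the rule exists to catch:

    # Each lambda closes over the variable `i`, not its value at definition
    # time, so all three callbacks observe the final value of `i`.
    callbacks = [lambda: i for i in range(3)]
    print([cb() for cb in callbacks])  # [2, 2, 2]

    # Binding the current value through a default argument fixes it.
    callbacks = [lambda i=i: i for i in range(3)]
    print([cb() for cb in callbacks])  # [0, 1, 2]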
@@ -112,7 +112,7 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):

     inference_providers = result.providers["inference"]
     assert len(inference_providers) == 2
-    assert set(x.provider_id for x in inference_providers) == {
+    assert {x.provider_id for x in inference_providers} == {
         "remote::ollama-00",
         "meta-reference-01",
     }

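Many hunks in this commit make the same mechanical change: `set(generator)` becomes a set comprehension (ruff rule C401). The two forms produce identical sets; the comprehension just drops the redundant call. A quick sketch:

    provider_ids = ["remote::ollama-00", "meta-reference-01", "meta-reference-01"]

    # Both build the same set; the comprehension avoids the extra set() call.
    assert set(x for x in provider_ids) == {x for x in provider_ids}
    assert {x for x in provider_ids} == {"remote::ollama-00", "meta-reference-01"}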
@@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec


 def stack_apis() -> List[Api]:
-    return [v for v in Api]
+    return list(Api)


 class AutoRoutedApiInfo(BaseModel):

@@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:


 def providable_apis() -> List[Api]:
-    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect]

@@ -115,8 +115,8 @@ async def resolve_impls(
     - flatmaps, sorts and resolves the providers in dependency order
     - for each API, produces either a (local, passthrough or router) implementation
     """
-    routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
-    router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
+    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
+    router_apis = {x.router_api for x in builtin_automatically_routed_apis()}

     providers_with_specs = {}

@@ -134,7 +134,7 @@ def rag_chat_page():
                 dict(
                     name="builtin::rag/knowledge_search",
                     args={
-                        "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
+                        "vector_db_ids": list(selected_vector_dbs),
                     },
                 )
             ],

@@ -797,10 +797,10 @@ class ChatAgent(ShieldRunnerMixin):
         self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
     ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
         # Determine which tools to include
-        agent_config_toolgroups = set(
-            (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
+        agent_config_toolgroups = {
+            toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
             for toolgroup in self.agent_config.toolgroups
-        )
+        }
         toolgroups_for_turn_set = (
             agent_config_toolgroups
             if toolgroups_for_turn is None

@@ -86,7 +86,6 @@ class MetaReferenceEvalImpl(
     ) -> Job:
         task_def = self.benchmarks[benchmark_id]
         dataset_id = task_def.dataset_id
-        candidate = task_config.eval_candidate
         scoring_functions = task_def.scoring_functions
         dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
         validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))

@@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
         logprobs = []
         stop_reason = None

-        tokenizer = self.generator.formatter.tokenizer
         for token_result in self.generator.completion(request):
             tokens.append(token_result.token)
             if token_result.text == "<|eot_id|>":

@@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
         return parse_message(maybe_json)
     except json.JSONDecodeError:
         return None
-    except ValueError as e:
+    except ValueError:
         return None

@@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
                 if isinstance(obj, TaskResponse):
                     yield obj.result

-        except GeneratorExit as e:
+        except GeneratorExit:
             self.request_socket.send(encode_msg(CancelSentinel()))
             while True:
                 obj_json = self.request_socket.send()

@@ -7,6 +7,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

+# The file gets a special treatment for now?
+# ruff: noqa: N803
+
 import unittest

 import torch

@@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
         )

         self.adapter_params = get_adapter_params(model)
-        self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
+        self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())

         set_trainable_params(model, self.adapter_params)

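Passing a list comprehension to `any()` (ruff rule C419) materializes the whole list before the check runs; the bare generator lets `any()` stop at the first match. A sketch with stand-in values:

    adapter_params = {"layers.0.lora_a": 0, "layers.0.magnitude": 0}

    # Builds a full list of booleans first:
    is_dora = any(["magnitude" in k for k in adapter_params])
    # Short-circuits as soon as one key matches:
    is_dora = any("magnitude" in k for k in adapter_params)
    assert is_dora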
@@ -133,7 +133,7 @@ class BraintrustScoringImpl(
     async def shutdown(self) -> None: ...

     async def list_scoring_functions(self) -> List[ScoringFn]:
-        scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
+        scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith("braintrust"), (
                 "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "

@@ -198,7 +198,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         if tool_prompt_format:
-            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
+            warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)

         await check_health(self._config)  # this raises errors

@@ -106,7 +106,7 @@ async def convert_chat_completion_request(
         payload.update(temperature=strategy.temperature)
     elif isinstance(strategy, TopKSamplingStrategy):
         if strategy.top_k != -1 and strategy.top_k < 1:
-            warnings.warn("top_k must be -1 or >= 1")
+            warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
         nvext.update(top_k=strategy.top_k)
     elif isinstance(strategy, GreedySamplingStrategy):
         nvext.update(top_k=-1)

@@ -168,7 +168,7 @@ def convert_completion_request(
         payload.update(top_p=request.sampling_params.top_p)
     elif request.sampling_params.strategy == "top_k":
         if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
-            warnings.warn("top_k must be -1 or >= 1")
+            warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
         nvext.update(top_k=request.sampling_params.top_k)
     elif request.sampling_params.strategy == "greedy":
         nvext.update(top_k=-1)

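The repeated `stacklevel=2` additions address bugbear's B028 (no explicit stacklevel): with the default `stacklevel=1` the warning is attributed to the `warnings.warn` call itself, while `stacklevel=2` points it at the caller, where the offending value actually came from. A small sketch:

    import warnings

    def convert(top_k: int) -> None:
        if top_k != -1 and top_k < 1:
            # stacklevel=2 makes the reported filename/line number refer to
            # convert()'s caller rather than to this line inside convert().
            warnings.warn("top_k must be -1 or >= 1", stacklevel=2)

    convert(0)  # the warning now points at this call site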
@@ -39,12 +39,11 @@ class Testeval:

     @pytest.mark.asyncio
     async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
-        eval_impl, benchmarks_impl, datasetio_impl, datasets_impl, models_impl = (
+        eval_impl, benchmarks_impl, datasetio_impl, datasets_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.benchmarks],
             eval_stack[Api.datasetio],
             eval_stack[Api.datasets],
-            eval_stack[Api.models],
         )

         await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")

@@ -92,11 +91,10 @@ class Testeval:

     @pytest.mark.asyncio
     async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
-        eval_impl, benchmarks_impl, datasets_impl, models_impl = (
+        eval_impl, benchmarks_impl, datasets_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.benchmarks],
             eval_stack[Api.datasets],
-            eval_stack[Api.models],
         )

         await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")

@@ -131,11 +129,10 @@ class Testeval:

     @pytest.mark.asyncio
     async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
-        eval_impl, benchmarks_impl, datasets_impl, models_impl = (
+        eval_impl, benchmarks_impl, datasets_impl = (
             eval_stack[Api.eval],
             eval_stack[Api.benchmarks],
             eval_stack[Api.datasets],
-            eval_stack[Api.models],
         )

         response = await datasets_impl.list_datasets()

@@ -18,54 +18,48 @@ from llama_stack.models.llama.sku_list import all_registered_models
 INFERENCE_APIS = ["chat_completion"]
 FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
 SUPPORTED_MODELS = {
-    "ollama": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_1b.value,
-        ]
-    ),
-    "fireworks": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_11b_vision.value,
-        ]
-    ),
-    "together": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_11b_vision.value,
-        ]
-    ),
+    "ollama": {
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_3_70b_instruct.value,
+        CoreModelId.llama_guard_3_8b.value,
+        CoreModelId.llama_guard_3_1b.value,
+    },
+    "fireworks": {
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_3_70b_instruct.value,
+        CoreModelId.llama_guard_3_8b.value,
+        CoreModelId.llama_guard_3_11b_vision.value,
+    },
+    "together": {
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_3_70b_instruct.value,
+        CoreModelId.llama_guard_3_8b.value,
+        CoreModelId.llama_guard_3_11b_vision.value,
+    },
 }

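Note that the repeated model ids in the `ollama` entry survive the rewrite; they are harmless (if redundant) because a set literal collapses duplicates, as this sketch shows:

    models = {"llama3.1-8b-instruct", "llama3.1-8b-instruct", "llama3.2-1b-instruct"}
    assert models == {"llama3.1-8b-instruct", "llama3.2-1b-instruct"}  # duplicates collapsed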
@@ -45,13 +45,11 @@ class TestScoring:
             scoring_functions_impl,
             datasetio_impl,
             datasets_impl,
-            models_impl,
         ) = (
             scoring_stack[Api.scoring],
             scoring_stack[Api.scoring_functions],
             scoring_stack[Api.datasetio],
             scoring_stack[Api.datasets],
-            scoring_stack[Api.models],
         )
         scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
         provider_id = scoring_fns_list[0].provider_id

@@ -102,13 +100,11 @@ class TestScoring:
             scoring_functions_impl,
             datasetio_impl,
             datasets_impl,
-            models_impl,
         ) = (
             scoring_stack[Api.scoring],
             scoring_stack[Api.scoring_functions],
             scoring_stack[Api.datasetio],
             scoring_stack[Api.datasets],
-            scoring_stack[Api.models],
         )
         await register_dataset(datasets_impl, for_rag=True)
         response = await datasets_impl.list_datasets()

@@ -163,13 +159,11 @@ class TestScoring:
             scoring_functions_impl,
             datasetio_impl,
             datasets_impl,
-            models_impl,
         ) = (
             scoring_stack[Api.scoring],
             scoring_stack[Api.scoring_functions],
             scoring_stack[Api.datasetio],
             scoring_stack[Api.datasets],
-            scoring_stack[Api.models],
         )
         await register_dataset(datasets_impl, for_rag=True)
         rows = await datasetio_impl.get_rows_paginated(

@@ -605,7 +605,7 @@ def convert_tool_call(
             tool_name=tool_call.function.name,
             arguments=json.loads(tool_call.function.arguments),
         )
-    except Exception as e:
+    except Exception:
         return UnparseableToolCall(
             call_id=tool_call.id or "",
             tool_name=tool_call.function.name or "",

@@ -876,7 +876,9 @@ async def convert_openai_chat_completion_stream(
                 # it is possible to have parallel tool calls in stream, but
                 # ChatCompletionResponseEvent only supports one per stream
                 if len(choice.delta.tool_calls) > 1:
-                    warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
+                    warnings.warn(
+                        "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                    )

                 if not enable_incremental_tool_calls:
                     yield ChatCompletionResponseStreamChunk(

@@ -36,7 +36,7 @@ class RedisKVStoreImpl(KVStore):
         value = await self.redis.get(key)
         if value is None:
             return None
-        ttl = await self.redis.ttl(key)
+        await self.redis.ttl(key)
         return value

     async def delete(self, key: str) -> None:

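Several hunks drop local bindings that are assigned but never read (flake8's F841). When the right-hand side has side effects, as with the Redis TTL call above, only the binding is removed and the call itself stays. A self-contained sketch with a stand-in client:

    import asyncio

    class FakeRedis:
        # Stand-in for the real client, just for this sketch.
        async def ttl(self, key: str) -> int:
            return 42

    async def get(redis: FakeRedis, key: str) -> None:
        # Before: ttl = await redis.ttl(key)  -> F841, `ttl` is never read.
        # After: the awaited call is kept for its effect; the binding is gone.
        await redis.ttl(key)

    asyncio.run(get(FakeRedis(), "session"))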
@@ -32,7 +32,7 @@ def aggregate_categorical_count(
     scoring_results: List[ScoringResultRow],
 ) -> Dict[str, Any]:
     scores = [str(r["score"]) for r in scoring_results]
-    unique_scores = sorted(list(set(scores)))
+    unique_scores = sorted(set(scores))
     return {"categorical_count": {s: scores.count(s) for s in unique_scores}}

@@ -66,7 +66,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
         return self.__class__.__name__

     def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
-        return [x for x in self.supported_fn_defs_registry.values()]
+        return list(self.supported_fn_defs_registry.values())

     def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
         if scoring_fn.identifier in self.supported_fn_defs_registry:

@@ -99,7 +99,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
             template = template_func()
             normal_deps, special_deps = get_provider_dependencies(template.providers)
             # Combine all dependencies in order: normal deps, special deps, server deps
-            all_deps = sorted(list(set(normal_deps + SERVER_DEPENDENCIES))) + sorted(list(set(special_deps)))
+            all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))

             return template.name, all_deps
         except Exception:

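The `sorted(list(set(...)))` cleanups correspond to ruff's C414 (unnecessary `list()` call within `sorted()`): `sorted()` accepts any iterable and always returns a new list, so the intermediate copy does nothing. For example:

    deps = ["fastapi", "uvicorn", "fastapi"]
    assert sorted(list(set(deps))) == sorted(set(deps)) == ["fastapi", "uvicorn"]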
@@ -123,39 +123,16 @@ select = [
     "I",  # isort
 ]
 ignore = [
-    "E203",
-    "E305",
-    "E402",
-    "E501", # line too long
-    "E721",
-    "E741",
-    "F405",
-    "F841",
-    "C408", # ignored because we like the dict keyword argument syntax
-    "E302",
-    "W291",
-    "E303",
-    "N812", # ignored because import torch.nn.functional as F is PyTorch convention
-    "N817", # ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
-    "E731", # allow usage of assigning lambda expressions
+    # The following ignores are desired by the project maintainers.
+    "E402", # Module level import not at top of file
+    "E501", # Line too long
+    "F405", # Maybe undefined or defined from star import
+    "C408", # Ignored because we like the dict keyword argument syntax
+    "N812", # Ignored because import torch.nn.functional as F is PyTorch convention
+
     # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
-    "C901",
-    "C405",
-    "C414",
-    "N803",
-    "N999",
-    "C403",
-    "C416",
-    "B028",
-    "C419",
-    "C401",
-    "B023",
-    # shebang has extra meaning in fbcode lints, so I think it's not worth trying
-    # to line this up with executable bit
-    "EXE001",
-    "N802", # random naming hints don't need
+    "C901", # Complexity of the function is too high
+    # these ignores are from flake8-bugbear; please fix!
+    "B007",
+    "B008",
 ]

@@ -3,3 +3,4 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+# ruff: noqa: N999

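Several license headers gain the same one-line pragma. `# ruff: noqa: N999` is a file-level suppression for N999 (invalid module name), which fires when a module's path is not a valid Python identifier; the affected paths are not visible in this view, so the exact trigger here is an inference. Note the difference in scope from a trailing `# noqa`:

    # ruff: noqa: N999
    # A "# ruff: noqa: <RULE>" comment at file scope disables that rule for the
    # entire file, unlike a trailing "# noqa" which covers only its own line.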
@@ -3,3 +3,4 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+# ruff: noqa: N999

@@ -117,7 +117,7 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
     assert len(providers) > 0, "No inference providers found"
     inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]

-    model_ids = set(m.identifier for m in client.models.list())
+    model_ids = {m.identifier for m in client.models.list()}
     model_ids.update(m.provider_resource_id for m in client.models.list())

     if text_model_id and text_model_id not in model_ids:

@@ -3,3 +3,4 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+# ruff: noqa: N999

@@ -176,7 +176,7 @@ def test_embedding_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError) as excinfo:
+    with pytest.raises(BadRequestError):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
         )

@@ -243,7 +243,7 @@ def test_embedding_text_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError) as excinfo:
+    with pytest.raises(BadRequestError):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
         )

@@ -139,7 +139,7 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
             "top_k": 1,
         },
     )
-    streamed_content = [chunk for chunk in response]
+    streamed_content = list(response)
     for chunk in streamed_content:
         if chunk.delta:  # if there's a token, we expect logprobs
             assert chunk.logprobs, "Logprobs should not be empty"

@@ -405,7 +405,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
                 assert delta.tool_call.tool_name == "get_object_namespace_list"
             if delta.type == "tool_call" and delta.parse_status == "failed":
                 # expect raw message that failed to parse in tool_call
-                assert type(delta.tool_call) == str
+                assert isinstance(delta.tool_call, str)
                 assert len(delta.tool_call) > 0
     else:
         for tc in response.completion_message.tool_calls:

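The `type(x) == str` to `isinstance(x, str)` change follows standard guidance (pycodestyle's E721): direct type comparison rejects subclasses, while `isinstance` expresses the intended check. A sketch:

    raw_tool_call = "unparseable {json"
    assert isinstance(raw_tool_call, str)  # idiomatic membership-style check

    class RawPayload(str):
        pass

    payload = RawPayload("unparseable {json")
    assert isinstance(payload, str)    # True: subclasses of str count
    assert not (type(payload) == str)  # direct comparison rejects the subclass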
@@ -42,29 +42,27 @@ def featured_models():


 SUPPORTED_MODELS = {
-    "ollama": set(
-        [
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_8b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_70b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_1_405b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_1b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_3b_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_11b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_2_90b_vision_instruct.value,
-            CoreModelId.llama3_3_70b_instruct.value,
-            CoreModelId.llama_guard_3_8b.value,
-            CoreModelId.llama_guard_3_1b.value,
-        ]
-    ),
-    "tgi": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
-    "vllm": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
+    "ollama": {
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_8b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_70b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_1_405b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
+        CoreModelId.llama3_2_1b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_3b_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_11b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_2_90b_vision_instruct.value,
+        CoreModelId.llama3_3_70b_instruct.value,
+        CoreModelId.llama_guard_3_8b.value,
+        CoreModelId.llama_guard_3_1b.value,
+    },
+    "tgi": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
+    "vllm": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
 }

@@ -3,3 +3,4 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+# ruff: noqa: N999

@@ -42,7 +42,7 @@ def code_scanner_shield_id(available_shields):

 @pytest.fixture(scope="session")
 def model_providers(llama_stack_client):
-    return set([x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"])
+    return {x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"}


 def test_unsafe_examples(llama_stack_client, llama_guard_text_shield_id):

@@ -3,3 +3,4 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+# ruff: noqa: N999