From c7139b0b6792e5c3b30d245b05c3870554c7e758 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 12 Mar 2025 11:59:21 -0700
Subject: [PATCH] fix: fix precommit (#1594)

# What does this PR do?
- fix precommit

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
CI

[//]: # (## Documentation)
---
 .../open-benchmark/open_benchmark.py          | 29 ++++++++++++-------
 llama_stack/templates/template.py             |  6 ++--
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index 7df33a715..2b40797f9 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -4,8 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
+from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import (
     BenchmarkInput,
@@ -15,21 +16,27 @@ from llama_stack.distribution.datatypes import (
     ShieldInput,
     ToolGroupInput,
 )
-from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig
+from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
+    SQLiteVectorIOConfig,
+)
 from llama_stack.providers.remote.inference.anthropic.config import AnthropicConfig
 from llama_stack.providers.remote.inference.gemini.config import GeminiConfig
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
-from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
-from llama_stack.providers.utils.inference.model_registry import (
-    ProviderModelEntry,
+from llama_stack.providers.remote.vector_io.pgvector.config import (
+    PGVectorVectorIOConfig,
+)
+from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
+from llama_stack.templates.template import (
+    DistributionTemplate,
+    RunConfigSettings,
+    get_model_registry,
 )
-from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
 
 
-def get_inference_providers() -> Tuple[List[Provider], List[ModelInput]]:
+def get_inference_providers() -> Tuple[List[Provider], Dict[str, List[ProviderModelEntry]]]:
     # in this template, we allow each API key to be optional
     providers = [
         (
@@ -164,7 +171,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="simpleqa",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/simpleqa"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
             metadata={
                 "path": "llamastack/simpleqa",
                 "split": "train",
@@ -178,7 +185,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="mmlu_cot",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/mmlu_cot"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
             metadata={
                 "path": "llamastack/mmlu_cot",
                 "name": "all",
@@ -193,7 +200,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="gpqa_cot",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
             metadata={
                 "path": "llamastack/gpqa_0shot_cot",
                 "name": "gpqa_main",
@@ -208,7 +215,7 @@ def get_distribution_template() -> DistributionTemplate:
         DatasetInput(
             dataset_id="math_500",
             provider_id="huggingface",
-            url={"uri": "https://huggingface.co/datasets/llamastack/math_500"},
+            url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
             metadata={
                 "path": "llamastack/math_500",
                 "split": "test",
diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py
index aa1ce144f..a5c8e80bc 100644
--- a/llama_stack/templates/template.py
+++ b/llama_stack/templates/template.py
@@ -30,7 +30,9 @@ from llama_stack.providers.utils.inference.model_registry import ProviderModelEn
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
 
-def get_model_registry(available_models: Dict[str, List[ProviderModelEntry]]) -> List[ModelInput]:
+def get_model_registry(
+    available_models: Dict[str, List[ProviderModelEntry]],
+) -> List[ModelInput]:
     models = []
     for provider_id, entries in available_models.items():
         for entry in entries:
@@ -193,7 +195,7 @@ class DistributionTemplate(BaseModel):
             default_models.append(
                 DefaultModel(
                     model_id=model_entry.provider_model_id,
-                    doc_string=f"({' -- '.join(doc_parts)})" if doc_parts else "",
+                    doc_string=(f"({' -- '.join(doc_parts)})" if doc_parts else ""),
                 )
             )