From 93ed8aa814c7086cb0a806f9bf0bb8c5f0dacbcd Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 26 Dec 2024 16:39:31 -0800 Subject: [PATCH] remove more imports --- llama_stack/apis/inference/inference.py | 5 ++++- llama_stack/apis/scoring/scoring.py | 6 +++--- llama_stack/distribution/routers/routers.py | 1 + llama_stack/providers/inline/scoring/braintrust/config.py | 4 +++- llama_stack/providers/tests/memory/fixtures.py | 3 +-- llama_stack/providers/utils/scoring/aggregation_utils.py | 3 ++- 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 28b9d9106..e48042091 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -7,7 +7,9 @@ from enum import Enum from typing import ( + Any, AsyncIterator, + Dict, List, Literal, Optional, @@ -32,8 +34,9 @@ from typing_extensions import Annotated from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.models import Model + from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_stack.apis.models import * # noqa: F403 class LogProbConfig(BaseModel): diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index a47620a3d..17d1426b5 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Protocol, runtime_checkable +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.scoring_functions import * # noqa: F403 +# from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams # mapping of metric to value diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index a25a848db..e874e6333 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -15,6 +15,7 @@ from llama_stack.apis.memory_banks.memory_banks import BankParams from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.scoring import * # noqa: F403 from llama_stack.apis.tools import * # noqa: F403 +from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import RoutingTable diff --git a/llama_stack/providers/inline/scoring/braintrust/config.py b/llama_stack/providers/inline/scoring/braintrust/config.py index e12249432..d4e0d9bcd 100644 --- a/llama_stack/providers/inline/scoring/braintrust/config.py +++ b/llama_stack/providers/inline/scoring/braintrust/config.py @@ -3,7 +3,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.scoring import * # noqa: F401, F403 +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field class BraintrustScoringConfig(BaseModel): diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index b2a5a87c9..10314ea03 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,8 +10,7 @@ import tempfile import pytest import pytest_asyncio -from llama_stack.apis.inference import ModelInput, ModelType - +from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig diff --git a/llama_stack/providers/utils/scoring/aggregation_utils.py b/llama_stack/providers/utils/scoring/aggregation_utils.py index 7b9d58944..ded53faca 100644 --- a/llama_stack/providers/utils/scoring/aggregation_utils.py +++ b/llama_stack/providers/utils/scoring/aggregation_utils.py @@ -6,7 +6,8 @@ import statistics from typing import Any, Dict, List -from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow +from llama_stack.apis.scoring import ScoringResultRow +from llama_stack.apis.scoring_functions import AggregationFunctionType def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: