From 21a6bd57eae924139ebabfcccca6be4dcd587c90 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 26 Dec 2024 17:17:03 -0800 Subject: [PATCH] fix imports --- llama_stack/distribution/routers/__init__.py | 6 ++-- llama_stack/distribution/routers/routers.py | 2 +- .../distribution/routers/routing_tables.py | 35 ++++++++++++++----- llama_stack/distribution/store/registry.py | 7 ++-- llama_stack/providers/registry/agents.py | 8 ++++- llama_stack/providers/registry/datasetio.py | 8 ++++- llama_stack/providers/registry/eval.py | 2 +- llama_stack/providers/registry/inference.py | 9 +++-- llama_stack/providers/registry/memory.py | 9 +++-- .../providers/registry/post_training.py | 2 +- llama_stack/providers/registry/safety.py | 2 +- llama_stack/providers/registry/scoring.py | 2 +- llama_stack/providers/registry/telemetry.py | 8 ++++- .../providers/registry/tool_runtime.py | 2 +- .../providers/tests/memory/fixtures.py | 2 +- llama_stack/providers/tests/resolver.py | 14 ++++++-- .../providers/tests/safety/test_safety.py | 6 ++-- 17 files changed, 89 insertions(+), 35 deletions(-) diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 693f1fbe2..f19a2bffc 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -4,10 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any +from typing import Any, Dict + +from llama_stack.distribution.datatypes import RoutedProtocol -from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.store import DistributionRegistry +from llama_stack.providers.datatypes import Api, RoutingTable from .routing_tables import ( DatasetsRoutingTable, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index e874e6333..8a05e53a4 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -16,7 +16,7 @@ from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.scoring import * # noqa: F403 from llama_stack.apis.tools import * # noqa: F403 from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import RoutingTable +from llama_stack.providers.datatypes import RoutingTable class MemoryRouter(Memory): diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 3fb086b72..a6d2b1032 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,19 +6,38 @@ from typing import Any, Dict, List, Optional -from llama_models.llama3.api.datatypes import * # noqa: F403 +# from llama_models.llama3.api.datatypes import * # noqa: F403 from pydantic import parse_obj_as from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.eval_tasks import * # noqa: F403 -from llama_stack.apis.memory_banks import * # noqa: F403 -from llama_stack.apis.models import * # noqa: F403 -from llama_stack.apis.shields import * # noqa: F403 -from llama_stack.apis.tools import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.apis.datasets import Dataset, Datasets +from llama_stack.apis.eval_tasks import EvalTask, EvalTasks +from llama_stack.apis.memory_banks import BankParams, MemoryBank, MemoryBanks +from llama_stack.apis.models import Model, Models, ModelType +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ( + ScoringFn, + ScoringFnParams, + ScoringFunctions, +) +from llama_stack.apis.shields import Shield, Shields +from llama_stack.apis.tools import ( + MCPToolGroupDef, + Tool, + ToolGroup, + ToolGroupDef, + ToolGroups, + UserDefinedToolGroupDef, +) +from llama_stack.distribution.datatypes import ( + RoutableObject, + RoutableObjectWithProvider, + RoutedProtocol, +) + from llama_stack.distribution.store import DistributionRegistry +from llama_stack.providers.datatypes import Api, RoutingTable def get_impl_api(p: Any) -> Api: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index f98c14443..686054dd2 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -13,11 +13,8 @@ import pydantic from llama_stack.distribution.datatypes import KVStoreConfig, RoutableObjectWithProvider from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR -from llama_stack.providers.utils.kvstore import ( - KVStore, - kvstore_impl, - SqliteKVStoreConfig, -) +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class DistributionRegistry(Protocol): diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 8b6c9027c..6595b1955 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) from llama_stack.providers.utils.kvstore import kvstore_dependencies diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index 403c41111..f83dcbc60 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index 718c7eae5..6901c3741 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 0ff557b9f..397e8b7ee 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -6,8 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 - +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) META_REFERENCE_DEPS = [ "accelerate", diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index c18bd3873..6867a9186 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -6,8 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 - +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) EMBEDDING_DEPS = [ "blobfile", diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index af8b660fa..3c5d06c05 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 99b0d2bd8..b9f7b6d78 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import ( +from llama_stack.providers.datatypes import ( AdapterSpec, Api, InlineProviderSpec, diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index f31ff44d7..ca09be984 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/telemetry.py b/llama_stack/providers/registry/telemetry.py index d367bf894..ba7e2f806 100644 --- a/llama_stack/providers/registry/telemetry.py +++ b/llama_stack/providers/registry/telemetry.py @@ -6,7 +6,13 @@ from typing import List -from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + InlineProviderSpec, + ProviderSpec, + remote_provider_spec, +) def available_providers() -> List[ProviderSpec]: diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index f3e6aead8..042aef9d9 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.distribution.datatypes import ( +from llama_stack.providers.datatypes import ( AdapterSpec, Api, InlineProviderSpec, diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 10314ea03..9a98526ab 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -18,7 +18,7 @@ from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.tests.resolver import construct_stack_for_test -from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 8bbb902cd..5a38aaecc 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -8,14 +8,24 @@ import json import tempfile from typing import Any, Dict, List, Optional -from llama_stack.distribution.datatypes import * # noqa: F403 +from pydantic import BaseModel + +from llama_stack.apis.datasets import DatasetInput +from llama_stack.apis.eval_tasks import EvalTaskInput +from llama_stack.apis.memory_banks import MemoryBankInput +from llama_stack.apis.models import ModelInput +from llama_stack.apis.scoring_functions import ScoringFnInput +from llama_stack.apis.shields import ShieldInput + from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config +from llama_stack.distribution.datatypes import Provider, StackRunConfig from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_remote_stack_impls from llama_stack.distribution.stack import construct_stack -from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig +from llama_stack.providers.datatypes import Api, RemoteProviderConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class TestStack(BaseModel): diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index b015e8b06..857fe57f9 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -6,11 +6,9 @@ import pytest -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.apis.safety import * # noqa: F403 - -from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.apis.inference import UserMessage +from llama_stack.apis.safety import ViolationLevel +from llama_stack.apis.shields import Shield # How to run this test: #