From b641902bfacc72904aaea452f800d29796c84124 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 26 Dec 2024 18:01:45 -0800 Subject: [PATCH] impls imports remove --- llama_stack/distribution/build.py | 11 +++++----- .../providers/remote/inference/vllm/vllm.py | 22 +++++++++++++++++-- .../providers/remote/memory/chroma/chroma.py | 10 +++++++-- .../remote/memory/pgvector/pgvector.py | 12 +++++++--- .../providers/remote/memory/qdrant/qdrant.py | 13 +++++++---- .../providers/remote/memory/sample/sample.py | 5 ++--- .../remote/memory/weaviate/weaviate.py | 10 +++++++-- .../remote/safety/bedrock/bedrock.py | 11 ++++++++-- .../providers/remote/safety/sample/sample.py | 5 ++--- .../tests/agents/test_persistence.py | 6 ++--- .../tests/datasetio/test_datasetio.py | 13 ++++++----- .../providers/tests/post_training/fixtures.py | 3 ++- .../tests/post_training/test_post_training.py | 15 ++++++++++--- 13 files changed, 97 insertions(+), 39 deletions(-) diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index bdda0349f..f376301f9 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -6,21 +6,22 @@ import logging from enum import Enum -from typing import List + +from pathlib import Path +from typing import Dict, List import pkg_resources from pydantic import BaseModel from termcolor import cprint -from llama_stack.distribution.utils.exec import run_with_pty - -from llama_stack.distribution.datatypes import * # noqa: F403 -from pathlib import Path +from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR +from llama_stack.distribution.utils.exec import run_with_pty +from llama_stack.providers.datatypes import Api log = logging.getLogger(__name__) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 7250d901f..f62ccaa58 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional, Union from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer @@ -13,7 +13,25 @@ from llama_models.sku_list import all_registered_models from openai import OpenAI -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + CompletionResponseStreamChunk, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + ResponseFormatType, + SamplingParams, + ToolChoice, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index aa8b481a3..c04d775ca 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -12,8 +12,14 @@ from urllib.parse import urlparse import chromadb from numpy.typing import NDArray -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index ffe164ecb..b2c720b2c 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import logging -from typing import List, Tuple +from typing import Any, Dict, List, Optional, Tuple import psycopg2 from numpy.typing import NDArray @@ -14,8 +14,14 @@ from psycopg2.extras import execute_values, Json from pydantic import BaseModel, parse_obj_as -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index bf9e943c4..b1d5bd7fa 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -6,16 +6,21 @@ import logging import uuid -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct -from llama_stack.apis.memory_banks import * # noqa: F403 +from llama_stack.apis.inference import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate -from llama_stack.apis.memory import * # noqa: F403 - from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig from llama_stack.providers.utils.memory.vector_store import ( BankWithIndex, diff --git a/llama_stack/providers/remote/memory/sample/sample.py b/llama_stack/providers/remote/memory/sample/sample.py index 09ea2f32c..b051eb544 100644 --- a/llama_stack/providers/remote/memory/sample/sample.py +++ b/llama_stack/providers/remote/memory/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.memory import Memory +from llama_stack.apis.memory_banks import MemoryBank from .config import SampleConfig -from llama_stack.apis.memory import * # noqa: F403 - - class SampleMemoryImpl(Memory): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index 8ee001cfa..f1433090d 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -14,8 +14,14 @@ from numpy.typing import NDArray from weaviate.classes.init import Auth from weaviate.classes.query import Filter -from llama_stack.apis.memory import * # noqa: F403 -from llama_stack.apis.memory_banks import MemoryBankType +from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.memory import ( + Chunk, + Memory, + MemoryBankDocument, + QueryDocumentsResponse, +) +from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate from llama_stack.providers.utils.memory.vector_store import ( diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 78e8105e0..fba7bf342 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -9,8 +9,15 @@ import logging from typing import Any, Dict, List -from llama_stack.apis.safety import * # noqa -from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.inference import Message + +from llama_stack.apis.safety import ( + RunShieldResponse, + Safety, + SafetyViolation, + ViolationLevel, +) +from llama_stack.apis.shields import Shield from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.bedrock.client import create_bedrock_client diff --git a/llama_stack/providers/remote/safety/sample/sample.py b/llama_stack/providers/remote/safety/sample/sample.py index 4069b8789..180e6c3b5 100644 --- a/llama_stack/providers/remote/safety/sample/sample.py +++ b/llama_stack/providers/remote/safety/sample/sample.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.safety import Safety +from llama_stack.apis.shields import Shield from .config import SampleConfig -from llama_stack.apis.safety import * # noqa: F403 - - class SampleSafetyImpl(Safety): def __init__(self, config: SampleConfig): self.config = config diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py index 97094cd7a..38eb7de55 100644 --- a/llama_stack/providers/tests/agents/test_persistence.py +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -6,9 +6,9 @@ import pytest -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.datatypes import * # noqa: F403 - +from llama_stack.apis.agents import AgentConfig, Turn +from llama_stack.apis.inference import SamplingParams, UserMessage +from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .fixtures import pick_inference_model diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 7d88b6115..46c99f5b3 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import os - -import pytest -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 import base64 import mimetypes +import os from pathlib import Path +import pytest + +from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType +from llama_stack.apis.datasets import Datasets + # How to run this test: # # pytest llama_stack/providers/tests/datasetio/test_datasetio.py diff --git a/llama_stack/providers/tests/post_training/fixtures.py b/llama_stack/providers/tests/post_training/fixtures.py index 17d9668b2..fd8a9e4f6 100644 --- a/llama_stack/providers/tests/post_training/fixtures.py +++ b/llama_stack/providers/tests/post_training/fixtures.py @@ -7,8 +7,9 @@ import pytest import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.common.content_types import URL + +from llama_stack.apis.common.type_system import StringType from llama_stack.apis.datasets import DatasetInput from llama_stack.apis.models import ModelInput diff --git a/llama_stack/providers/tests/post_training/test_post_training.py b/llama_stack/providers/tests/post_training/test_post_training.py index 4ecc05187..0645cd555 100644 --- a/llama_stack/providers/tests/post_training/test_post_training.py +++ b/llama_stack/providers/tests/post_training/test_post_training.py @@ -4,9 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import pytest -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.post_training import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 + +from llama_stack.apis.common.type_system import JobStatus +from llama_stack.apis.post_training import ( + Checkpoint, + DataConfig, + LoraFinetuningConfig, + OptimizerConfig, + PostTrainingJob, + PostTrainingJobArtifactsResponse, + PostTrainingJobStatusResponse, + TrainingConfig, +) # How to run this test: #