chore: Consistent naming for VectorIO providers (#1023)

# What does this PR do?

This changes all VectorIO providers classes to follow the pattern
`<ProviderName>VectorIOConfig` and `<ProviderName>VectorIOAdapter`. All
API endpoints for VectorIOs are currently consistent with `/vector-io`.

Note that API endpoint for VectorDB stay unchanged as `/vector-dbs`. 

## Test Plan

I don't have a way to test all providers. This is a simple renaming so
things should work as expected.

---------

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
Yuan Tang 2025-02-13 13:15:49 -05:00 committed by GitHub
parent e4a1579e63
commit 8ff27b58fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 85 additions and 86 deletions

View file

@ -8,10 +8,10 @@ from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaRemoteImplConfig
from .config import ChromaVectorIOConfig
async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]):
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .chroma import ChromaVectorIOAdapter
impl = ChromaVectorIOAdapter(config, deps[Api.inference])

View file

@ -16,13 +16,12 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import ChromaRemoteImplConfig
from .config import ChromaVectorIOConfig
log = logging.getLogger(__name__)
@ -89,7 +88,7 @@ class ChromaIndex(EmbeddingIndex):
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
config: Union[ChromaVectorIOConfig, ChromaVectorIOConfig],
inference_api: Api.inference,
) -> None:
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
@ -100,7 +99,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.cache = {}
async def initialize(self) -> None:
if isinstance(self.config, ChromaRemoteImplConfig):
if isinstance(self.config, ChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)

View file

@ -9,7 +9,7 @@ from typing import Any, Dict
from pydantic import BaseModel
class ChromaRemoteImplConfig(BaseModel):
class ChromaVectorIOConfig(BaseModel):
url: str
@classmethod

View file

@ -8,12 +8,12 @@ from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import PGVectorConfig
from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorDBAdapter
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorDBAdapter(config, deps[Api.inference])
impl = PGVectorVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
@json_schema_type
class PGVectorConfig(BaseModel):
class PGVectorVectorIOConfig(BaseModel):
host: str = Field(default="localhost")
port: int = Field(default=5432)
db: str = Field(default="postgres")

View file

@ -22,7 +22,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex,
)
from .config import PGVectorConfig
from .config import PGVectorVectorIOConfig
log = logging.getLogger(__name__)
@ -121,8 +121,8 @@ class PGVectorIndex(EmbeddingIndex):
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.conn = None

View file

@ -8,12 +8,12 @@ from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantConfig
from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorDBAdapter
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorIOAdapter
impl = QdrantVectorDBAdapter(config, deps[Api.inference])
impl = QdrantVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -11,7 +11,7 @@ from pydantic import BaseModel
@json_schema_type
class QdrantConfig(BaseModel):
class QdrantVectorIOConfig(BaseModel):
location: Optional[str] = None
url: Optional[str] = None
port: Optional[int] = 6333

View file

@ -21,7 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex,
)
from .config import QdrantConfig
from .config import QdrantVectorIOConfig
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
@ -98,8 +98,8 @@ class QdrantIndex(EmbeddingIndex):
await self.client.delete_collection(collection_name=self.collection_name)
class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {}

View file

@ -6,12 +6,12 @@
from typing import Any
from .config import SampleConfig
from .config import SampleVectorIOConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleMemoryImpl
async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any:
from .sample import SampleVectorIOImpl
impl = SampleMemoryImpl(config)
impl = SampleVectorIOImpl(config)
await impl.initialize()
return impl

View file

@ -7,6 +7,6 @@
from pydantic import BaseModel
class SampleConfig(BaseModel):
class SampleVectorIOConfig(BaseModel):
host: str = "localhost"
port: int = 9999

View file

@ -7,11 +7,11 @@
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import VectorIO
from .config import SampleConfig
from .config import SampleVectorIOConfig
class SampleMemoryImpl(VectorIO):
def __init__(self, config: SampleConfig):
class SampleVectorIOImpl(VectorIO):
def __init__(self, config: SampleVectorIOConfig):
self.config = config
async def register_vector_db(self, vector_db: VectorDB) -> None:

View file

@ -8,12 +8,12 @@ from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig # noqa: F401
async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateMemoryAdapter
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateVectorIOAdapter
impl = WeaviateMemoryAdapter(config, deps[Api.inference])
impl = WeaviateVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -12,5 +12,5 @@ class WeaviateRequestProviderData(BaseModel):
weaviate_cluster_url: str
class WeaviateConfig(BaseModel):
class WeaviateVectorIOConfig(BaseModel):
pass

View file

@ -23,7 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import (
VectorDBWithIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
log = logging.getLogger(__name__)
@ -85,12 +85,12 @@ class WeaviateIndex(EmbeddingIndex):
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
class WeaviateMemoryAdapter(
class WeaviateVectorIOAdapter(
VectorIO,
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.client_cache = {}