mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-25 17:31:58 +00:00
feat: add auto-generated CI documentation pre-commit hook (#2890)
Our CI is entirely undocumented, this commit adds a README.md file with a table of the current CI and what is does --------- Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent
7f834339ba
commit
b381ed6d64
93 changed files with 495 additions and 477 deletions
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import FireworksImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
|
@ -256,7 +256,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
"stream": bool(request.stream),
|
||||
**self._build_options(request.sampling_params, request.response_format, request.logprobs),
|
||||
}
|
||||
logger.debug(f"params to fireworks: {params}")
|
||||
log.debug(f"params to fireworks: {params}")
|
||||
|
||||
return params
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
|
@ -11,8 +10,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
|
|
@ -33,6 +32,7 @@ from llama_stack.apis.inference import (
|
|||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
|
|
@ -54,7 +54,7 @@ from .openai_utils import (
|
|||
)
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||
|
|
@ -75,7 +75,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
|||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
||||
logger.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
log.info(f"Initializing NVIDIAInferenceAdapter({config.url})...")
|
||||
|
||||
if _is_nvidia_hosted(config):
|
||||
if not config.api_key:
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from . import NVIDIAConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||
|
|
@ -44,7 +45,7 @@ async def check_health(config: NVIDIAConfig) -> None:
|
|||
RuntimeError: If the server is not running or ready
|
||||
"""
|
||||
if not _is_nvidia_hosted(config):
|
||||
logger.info("Checking NVIDIA NIM health...")
|
||||
log.info("Checking NVIDIA NIM health...")
|
||||
try:
|
||||
is_live, is_ready = await _get_health(config.url)
|
||||
if not is_live:
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class OllamaInferenceAdapter(
|
||||
|
|
@ -117,10 +117,10 @@ class OllamaInferenceAdapter(
|
|||
return self._openai_client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
log.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
health_response = await self.health()
|
||||
if health_response["status"] == HealthStatus.ERROR:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
|
|
@ -339,7 +339,7 @@ class OllamaInferenceAdapter(
|
|||
"options": sampling_options,
|
||||
"stream": request.stream,
|
||||
}
|
||||
logger.debug(f"params to ollama: {params}")
|
||||
log.debug(f"params to ollama: {params}")
|
||||
|
||||
return params
|
||||
|
||||
|
|
@ -437,7 +437,7 @@ class OllamaInferenceAdapter(
|
|||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
log.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
|
@ -12,8 +11,6 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
|||
from .config import OpenAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
#
|
||||
# This OpenAI adapter implements Inference methods using two mixins -
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
|
|
@ -34,6 +33,7 @@ from llama_stack.apis.inference import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.sku_list import all_registered_models
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
|
|
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def build_hf_repo_model_entries():
|
||||
|
|
@ -307,7 +307,7 @@ class TGIAdapter(_HfAdapter):
|
|||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
logger.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.url,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
from .config import TogetherImplConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
log = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
|
|
@ -232,7 +232,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
"stream": request.stream,
|
||||
**self._build_options(request.sampling_params, request.logprobs, request.response_format),
|
||||
}
|
||||
logger.debug(f"params to together: {params}")
|
||||
log.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -15,8 +14,6 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
|
|||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
|
|
@ -16,12 +15,13 @@ from llama_stack.apis.safety import (
|
|||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||
|
||||
from .config import BedrockSafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
|
|
@ -76,13 +76,13 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
"""
|
||||
|
||||
shield_params = shield.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
log.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
|
||||
# - convert the messages into format Bedrock expects
|
||||
content_messages = []
|
||||
for message in messages:
|
||||
content_messages.append({"text": {"text": message.content}})
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
response = self.bedrock_runtime_client.apply_guardrail(
|
||||
guardrailIdentifier=shield.provider_resource_id,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
|
@ -17,8 +16,6 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
|||
|
||||
from .config import NVIDIASafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: NVIDIASafetyConfig) -> None:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
|
|
@ -20,12 +19,13 @@ from llama_stack.apis.safety import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new
|
||||
|
||||
from .config import SambaNovaSafetyConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="safety")
|
||||
|
||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
|||
"guard" not in shield.provider_resource_id.lower()
|
||||
or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
|
||||
):
|
||||
logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||
log.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
|
||||
|
||||
async def unregister_shield(self, identifier: str) -> None:
|
||||
pass
|
||||
|
|
@ -79,9 +79,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
|
|||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
shield_params = shield.params
|
||||
logger.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
log.debug(f"run_shield::{shield_params}::messages={messages}")
|
||||
content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
|
||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
log.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
||||
|
||||
response = litellm.completion(
|
||||
model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
|
@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
|
@ -32,8 +32,6 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||
|
||||
VERSION = "v3"
|
||||
|
|
@ -43,6 +41,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:chroma:{VERSION}::"
|
|||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:chroma:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:chroma:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
# this is a helper to allow us to use async and non-async chroma clients interchangeably
|
||||
async def maybe_await(result):
|
||||
|
|
@ -92,7 +92,7 @@ class ChromaIndex(EmbeddingIndex):
|
|||
doc = json.loads(doc)
|
||||
chunk = Chunk(**doc)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {doc}")
|
||||
logger.exception(f"Failed to parse document: {doc}")
|
||||
continue
|
||||
|
||||
score = 1.0 / float(dist) if dist != 0 else float("inf")
|
||||
|
|
@ -137,7 +137,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
inference_api: Api.inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
logger.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.client = None
|
||||
|
|
@ -150,7 +150,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
self.vector_db_store = self.kvstore
|
||||
|
||||
if isinstance(self.config, RemoteChromaVectorIOConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
logger.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
url = self.config.url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
|
|
@ -159,7 +159,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
self.client = await chromadb.AsyncHttpClient(host=parsed.hostname, port=parsed.port)
|
||||
else:
|
||||
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
logger.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
self.client = chromadb.PersistentClient(path=self.config.db_path)
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
|
|
@ -182,7 +182,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id not in self.cache:
|
||||
log.warning(f"Vector DB {vector_db_id} not found")
|
||||
logger.warning(f"Vector DB {vector_db_id} not found")
|
||||
return
|
||||
|
||||
await self.cache[vector_db_id].index.delete()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
log = get_logger(name=__name__, category="core")
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||
|
|
@ -68,7 +68,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
)
|
||||
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
log.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||
# Create schema for vector search
|
||||
schema = self.client.create_schema()
|
||||
schema.add_field(
|
||||
|
|
@ -147,7 +147,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
data=data,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
log.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
|
||||
raise e
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
|
@ -203,7 +203,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing BM25 search: {e}")
|
||||
log.error(f"Error performing BM25 search: {e}")
|
||||
# Fallback to simple text search
|
||||
return await self._fallback_keyword_search(query_string, k, score_threshold)
|
||||
|
||||
|
|
@ -247,7 +247,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
||||
log.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
|
@ -288,10 +288,10 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
log.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
log.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
|
|
@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import (
|
|||
QueryChunksResponse,
|
||||
VectorIO,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
|
@ -33,8 +33,6 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
|
||||
|
|
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
|
|||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def check_extension_version(cur):
|
||||
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
|
||||
|
|
@ -187,7 +187,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
logger.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
await self.initialize_openai_vector_stores()
|
||||
|
||||
|
|
@ -203,7 +203,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
version = check_extension_version(cur)
|
||||
if version:
|
||||
log.info(f"Vector extension version: {version}")
|
||||
logger.info(f"Vector extension version: {version}")
|
||||
else:
|
||||
raise RuntimeError("Vector extension is not installed.")
|
||||
|
||||
|
|
@ -216,13 +216,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to PGVector database server")
|
||||
logger.exception("Could not connect to PGVector database server")
|
||||
raise RuntimeError("Could not connect to PGVector database server") from e
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self.conn is not None:
|
||||
self.conn.close()
|
||||
log.info("Connection to PGVector database server closed")
|
||||
logger.info("Connection to PGVector database server closed")
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
# Persist vector DB metadata in the KV store
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import (
|
|||
VectorStoreChunkingStrategy,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||
|
|
@ -35,13 +35,14 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
|
||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
CHUNK_ID_KEY = "_chunk_id"
|
||||
|
||||
# KV store prefixes for vector databases
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
def convert_id(_id: str) -> str:
|
||||
"""
|
||||
|
|
@ -96,7 +97,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
points_selector=models.PointIdsList(points=[convert_id(chunk_id)]),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
||||
logger.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}")
|
||||
raise
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
|
|
@ -118,7 +119,7 @@ class QdrantIndex(EmbeddingIndex):
|
|||
try:
|
||||
chunk = Chunk(**point.payload["chunk_content"])
|
||||
except Exception:
|
||||
log.exception("Failed to parse chunk")
|
||||
logger.exception("Failed to parse chunk")
|
||||
continue
|
||||
|
||||
chunks.append(chunk)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import weaviate
|
||||
|
|
@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files
|
|||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
|
||||
from llama_stack.core.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
|
|
@ -33,8 +33,6 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
|||
|
||||
from .config import WeaviateVectorIOConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
|
||||
|
|
@ -42,6 +40,8 @@ OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
|
|||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
|
||||
|
||||
logger = get_logger(__name__, category="core")
|
||||
|
||||
|
||||
class WeaviateIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
|
|
@ -102,7 +102,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunk_dict = json.loads(chunk_json)
|
||||
chunk = Chunk(**chunk_dict)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse document: {chunk_json}")
|
||||
logger.exception(f"Failed to parse document: {chunk_json}")
|
||||
continue
|
||||
|
||||
score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")
|
||||
|
|
@ -171,7 +171,7 @@ class WeaviateVectorIOAdapter(
|
|||
|
||||
def _get_client(self) -> weaviate.Client:
|
||||
if "localhost" in self.config.weaviate_cluster_url:
|
||||
log.info("using Weaviate locally in container")
|
||||
logger.info("using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
|
|
@ -179,7 +179,7 @@ class WeaviateVectorIOAdapter(
|
|||
port=port,
|
||||
)
|
||||
else:
|
||||
log.info("Using Weaviate remote cluster with URL")
|
||||
logger.info("Using Weaviate remote cluster with URL")
|
||||
key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
|
||||
if key in self.client_cache:
|
||||
return self.client_cache[key]
|
||||
|
|
@ -197,7 +197,7 @@ class WeaviateVectorIOAdapter(
|
|||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
else:
|
||||
self.kvstore = None
|
||||
log.info("No kvstore configured, registry will not persist across restarts")
|
||||
logger.info("No kvstore configured, registry will not persist across restarts")
|
||||
|
||||
# Load existing vector DB definitions
|
||||
if self.kvstore is not None:
|
||||
|
|
@ -254,7 +254,7 @@ class WeaviateVectorIOAdapter(
|
|||
client = self._get_client()
|
||||
sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
|
||||
if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
|
||||
log.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
logger.warning(f"Vector DB {sanitized_collection_name} not found")
|
||||
return
|
||||
client.collections.delete(sanitized_collection_name)
|
||||
await self.cache[sanitized_collection_name].index.delete()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue