Merge branch 'main' into allow-dynamic-models-nvidia
Commit c2ab8988e6
127 changed files with 3997 additions and 3394 deletions

@@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
     "anthropic/claude-3-5-haiku-latest",
 ]
 
+SAFETY_MODELS_ENTRIES = []
+
-MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
-    ProviderModelEntry(
-        provider_model_id="anthropic/voyage-3",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 1024, "context_length": 32000},
-    ),
-    ProviderModelEntry(
-        provider_model_id="anthropic/voyage-3-lite",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 512, "context_length": 32000},
-    ),
-    ProviderModelEntry(
-        provider_model_id="anthropic/voyage-code-3",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 1024, "context_length": 32000},
-    ),
-]
+MODEL_ENTRIES = (
+    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+    + [
+        ProviderModelEntry(
+            provider_model_id="anthropic/voyage-3",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 1024, "context_length": 32000},
+        ),
+        ProviderModelEntry(
+            provider_model_id="anthropic/voyage-3-lite",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 512, "context_length": 32000},
+        ),
+        ProviderModelEntry(
+            provider_model_id="anthropic/voyage-code-3",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 1024, "context_length": 32000},
+        ),
+    ]
+    + SAFETY_MODELS_ENTRIES
+)

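This first hunk is the Anthropic provider's model list; the same reshaping repeats for Bedrock, Cerebras, Databricks, Fireworks, Gemini, Groq, NVIDIA, Ollama, OpenAI, RunPod, SambaNova, and Together in the hunks below: LLM ids become plain entries, embedding models keep their metadata, and a (possibly empty) SAFETY_MODELS_ENTRIES list is split out and appended last so safety models stay separately identifiable. A minimal self-contained sketch of the shape, using stand-in types rather than the real llama_stack classes:

```python
# Sketch of the MODEL_ENTRIES composition pattern used throughout this commit.
# ProviderModelEntry and ModelType are simplified stand-ins for the originals.
from dataclasses import dataclass, field
from enum import Enum


class ModelType(str, Enum):
    llm = "llm"
    embedding = "embedding"


@dataclass
class ProviderModelEntry:
    provider_model_id: str
    model_type: ModelType = ModelType.llm
    metadata: dict = field(default_factory=dict)


LLM_MODEL_IDS = ["anthropic/claude-3-5-haiku-latest"]

SAFETY_MODELS_ENTRIES: list[ProviderModelEntry] = []  # empty when a provider has no guard models

MODEL_ENTRIES = (
    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
    + [
        ProviderModelEntry(
            provider_model_id="anthropic/voyage-3",
            model_type=ModelType.embedding,
            metadata={"embedding_dimension": 1024, "context_length": 32000},
        ),
    ]
    + SAFETY_MODELS_ENTRIES  # safety models always appended last
)

print(len(MODEL_ENTRIES))  # 2
```
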
@@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = []
+
+
+# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta.llama3-1-8b-instruct-v1:0",
@@ -22,4 +26,4 @@ MODEL_ENTRIES = [
         "meta.llama3-1-405b-instruct-v1:0",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-]
+] + SAFETY_MODELS_ENTRIES

@@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = []
+
+# https://inference-docs.cerebras.ai/models
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1-8b",
@@ -18,4 +21,8 @@ MODEL_ENTRIES = [
         "llama-3.3-70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-]
+    build_hf_repo_model_entry(
+        "llama-4-scout-17b-16e-instruct",
+        CoreModelId.llama4_scout_17b_16e_instruct.value,
+    ),
+] + SAFETY_MODELS_ENTRIES

@@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import DatabricksImplConfig
 
-model_entries = [
+SAFETY_MODELS_ENTRIES = []
+
+# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
+MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "databricks-meta-llama-3-1-70b-instruct",
         CoreModelId.llama3_1_70b_instruct.value,
@@ -56,7 +59,7 @@ model_entries = [
         "databricks-meta-llama-3-1-405b-instruct",
         CoreModelId.llama3_1_405b_instruct.value,
     ),
-]
+] + SAFETY_MODELS_ENTRIES
 
 
 class DatabricksInferenceAdapter(
@@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
     OpenAICompletionToLlamaStackMixin,
 ):
     def __init__(self, config: DatabricksImplConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=model_entries)
+        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
         self.config = config
 
     async def initialize(self) -> None:

@@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    build_hf_repo_model_entry(
+        "accounts/fireworks/models/llama-guard-3-8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "accounts/fireworks/models/llama-guard-3-11b-vision",
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "accounts/fireworks/models/llama-v3p1-8b-instruct",
@@ -40,14 +51,6 @@ MODEL_ENTRIES = [
         "accounts/fireworks/models/llama-v3p3-70b-instruct",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    build_hf_repo_model_entry(
-        "accounts/fireworks/models/llama-guard-3-8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "accounts/fireworks/models/llama-guard-3-11b-vision",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
     build_hf_repo_model_entry(
         "accounts/fireworks/models/llama4-scout-instruct-basic",
         CoreModelId.llama4_scout_17b_16e_instruct.value,
@@ -64,4 +67,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES

@@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
     "gemini/gemini-2.5-pro",
 ]
 
+SAFETY_MODELS_ENTRIES = []
+
-MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
-    ProviderModelEntry(
-        provider_model_id="gemini/text-embedding-004",
-        model_type=ModelType.embedding,
-        metadata={"embedding_dimension": 768, "context_length": 2048},
-    ),
-]
+MODEL_ENTRIES = (
+    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+    + [
+        ProviderModelEntry(
+            provider_model_id="gemini/text-embedding-004",
+            model_type=ModelType.embedding,
+            metadata={"embedding_dimension": 768, "context_length": 2048},
+        ),
+    ]
+    + SAFETY_MODELS_ENTRIES
+)

@@ -38,24 +38,18 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
             provider_data_api_key_field="groq_api_key",
         )
         self.config = config
-        self._openai_client = None
 
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None
 
     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            self._openai_client = AsyncOpenAI(
-                base_url=f"{self.config.url}/openai/v1",
-                api_key=self.config.api_key,
-            )
-        return self._openai_client
+        return AsyncOpenAI(
+            base_url=f"{self.config.url}/openai/v1",
+            api_key=self.get_api_key(),
+        )
 
     async def openai_chat_completion(
         self,

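Dropping the cached `_openai_client` in the Groq adapter (the OpenAI and Together adapters below get the same treatment) means the key is resolved on every call via `get_api_key()`, which in the LiteLLM mixin can come from the request's provider data rather than only the static config. A hedged sketch of the idea, with the key-resolution order modeled but simplified:

```python
# Sketch of the per-call client pattern; get_api_key() here imitates the
# LiteLLMOpenAIMixin behavior (config key first, else request-scoped key)
# but is a simplified stand-in.
from openai import AsyncOpenAI


class ClientPerCallAdapter:
    def __init__(self, base_url: str, config_api_key: str | None = None):
        self.base_url = base_url
        self.config_api_key = config_api_key
        self.request_api_key: str | None = None  # would be filled from request headers

    def get_api_key(self) -> str:
        key = self.config_api_key or self.request_api_key
        if not key:
            raise ValueError("no API key in config or request provider data")
        return key

    def _get_openai_client(self) -> AsyncOpenAI:
        # A fresh client per call: no stale cached key, nothing to close on shutdown.
        return AsyncOpenAI(base_url=self.base_url, api_key=self.get_api_key())
```
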
@@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = []
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "groq/llama3-8b-8192",
@@ -51,4 +53,4 @@ MODEL_ENTRIES = [
         "groq/meta-llama/llama-4-maverick-17b-128e-instruct",
         CoreModelId.llama4_maverick_17b_128e_instruct.value,
     ),
-]
+] + SAFETY_MODELS_ENTRIES

@@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = []
+
+
 # https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta/llama3-8b-instruct",
@@ -99,4 +102,4 @@ MODEL_ENTRIES = [
     ),
     # TODO(mf): how do we handle Nemotron models?
     # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
-]
+] + SAFETY_MODELS_ENTRIES

@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
             "context_length": 8192,
         },
     ),
-]
+] + SAFETY_MODELS_ENTRIES

@@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
     "text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
     "text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
 }
+SAFETY_MODELS_ENTRIES = []
 
 
-MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [
-    ProviderModelEntry(
-        provider_model_id=model_id,
-        model_type=ModelType.embedding,
-        metadata={
-            "embedding_dimension": model_info.embedding_dimension,
-            "context_length": model_info.context_length,
-        },
-    )
-    for model_id, model_info in EMBEDDING_MODEL_IDS.items()
-]
+MODEL_ENTRIES = (
+    [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
+    + [
+        ProviderModelEntry(
+            provider_model_id=model_id,
+            model_type=ModelType.embedding,
+            metadata={
+                "embedding_dimension": model_info.embedding_dimension,
+                "context_length": model_info.context_length,
+            },
+        )
+        for model_id, model_info in EMBEDDING_MODEL_IDS.items()
+    ]
+    + SAFETY_MODELS_ENTRIES
+)

@@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # if we do not set this, users will be exposed to the
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
-        self._openai_client = AsyncOpenAI(
-            api_key=self.config.api_key,
-        )
 
     async def initialize(self) -> None:
         await super().initialize()
@@ -69,6 +66,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
     async def shutdown(self) -> None:
         await super().shutdown()
 
+    def _get_openai_client(self) -> AsyncOpenAI:
+        return AsyncOpenAI(
+            api_key=self.get_api_key(),
+        )
+
     async def openai_completion(
         self,
         model: str,
@@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             user=user,
             suffix=suffix,
         )
-        return await self._openai_client.completions.create(**params)
+        return await self._get_openai_client().completions.create(**params)
 
     async def openai_chat_completion(
         self,
@@ -176,7 +178,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             top_p=top_p,
             user=user,
         )
-        return await self._openai_client.chat.completions.create(**params)
+        return await self._get_openai_client().chat.completions.create(**params)
 
     async def openai_embeddings(
         self,
@@ -204,7 +206,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
             params["user"] = user
 
         # Call OpenAI embeddings API
-        response = await self._openai_client.embeddings.create(**params)
+        response = await self._get_openai_client().embeddings.create(**params)
 
         data = []
         for i, embedding_data in enumerate(response.data):

@@ -11,7 +11,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.apis.inference import OpenAIEmbeddingsResponse
 
 # from llama_stack.providers.datatypes import ModelsProtocolPrivate
-from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAIChatCompletionToLlamaStackMixin,
     OpenAICompletionToLlamaStackMixin,
@@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import RunpodImplConfig
 
+# https://docs.runpod.io/serverless/vllm/overview#compatible-models
+# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
 RUNPOD_SUPPORTED_MODELS = {
     "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
     "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
@@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
     "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
 }
 
+SAFETY_MODELS_ENTRIES = []
+
+# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
+MODEL_ENTRIES = [
+    build_hf_repo_model_entry(provider_model_id, model_descriptor)
+    for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
+] + SAFETY_MODELS_ENTRIES
+
 
 class RunpodInferenceAdapter(
     ModelRegistryHelper,

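RunPod's entries are now derived from the RUNPOD_SUPPORTED_MODELS mapping instead of being maintained by hand. A standalone sketch of the derivation; the real build_hf_repo_model_entry returns a ProviderModelEntry with alias handling, stood in here by a simple tuple factory:

```python
# Sketch: deriving model entries from a provider-id -> model-descriptor mapping,
# mirroring the RunPod hunk above (tuple factory is a stand-in).
RUNPOD_SUPPORTED_MODELS = {
    "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
    "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
}

SAFETY_MODELS_ENTRIES: list[tuple[str, str]] = []


def build_hf_repo_model_entry(provider_model_id: str, model_descriptor: str) -> tuple[str, str]:
    return (provider_model_id, model_descriptor)


MODEL_ENTRIES = [
    build_hf_repo_model_entry(provider_model_id, model_descriptor)
    for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES

assert MODEL_ENTRIES[0] == ("Llama3.1-8B", "meta-llama/Llama-3.1-8B")
```
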
@@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    build_hf_repo_model_entry(
+        "sambanova/Meta-Llama-Guard-3-8B",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+]
+
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "sambanova/Meta-Llama-3.1-8B-Instruct",
@@ -46,8 +54,4 @@ MODEL_ENTRIES = [
         "sambanova/Llama-4-Maverick-17B-128E-Instruct",
         CoreModelId.llama4_maverick_17b_128e_instruct.value,
     ),
-    build_hf_repo_model_entry(
-        "sambanova/Meta-Llama-Guard-3-8B",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-]
+] + SAFETY_MODELS_ENTRIES

@@ -7,6 +7,7 @@
 import json
 from collections.abc import Iterable
 
+import requests
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
 )
@@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
     ToolResponseMessage,
     UserMessage,
 )
+from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -176,10 +178,11 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
 
     def __init__(self, config: SambaNovaImplConfig):
         self.config = config
+        self.environment_available_models = []
         LiteLLMOpenAIMixin.__init__(
             self,
             model_entries=MODEL_ENTRIES,
-            api_key_from_config=self.config.api_key,
+            api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
             provider_data_api_key_field="sambanova_api_key",
         )
 
@@ -246,6 +249,22 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
             **get_sampling_options(request.sampling_params),
         }
 
+    async def register_model(self, model: Model) -> Model:
+        model_id = self.get_provider_model_id(model.provider_resource_id)
+
+        list_models_url = self.config.url + "/models"
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
+
+        if model_id.split("sambanova/")[-1] not in self.environment_available_models:
+            logger.warning(f"Model {model_id} not available in {list_models_url}")
+        return model
+
     async def initialize(self):
         await super().initialize()

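register_model now fetches the provider's /models listing once, caches it on the adapter, and only warns (rather than failing) when a model is missing, which is what lets dynamically served models through. A self-contained sketch of that cached availability check, with error handling trimmed:

```python
# Sketch of the cached availability check from register_model above; the
# "/models" endpoint shape and "sambanova/" prefix handling follow the diff.
import requests


class AvailabilityChecker:
    def __init__(self, base_url: str):
        self.base_url = base_url
        self.environment_available_models: list[str] = []

    def is_available(self, model_id: str) -> bool:
        if not self.environment_available_models:
            # One network call, then served from the in-process cache.
            response = requests.get(self.base_url + "/models", timeout=10)
            response.raise_for_status()
            self.environment_available_models = [m.get("id") for m in response.json().get("data", [])]
        # Provider ids are prefixed, e.g. "sambanova/Meta-Llama-3.1-8B-Instruct".
        return model_id.split("sambanova/")[-1] in self.environment_available_models
```
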
@@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
 
+SAFETY_MODELS_ENTRIES = [
+    build_hf_repo_model_entry(
+        "meta-llama/Llama-Guard-3-8B",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
+        CoreModelId.llama_guard_3_11b_vision.value,
+    ),
+]
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@@ -40,14 +50,6 @@ MODEL_ENTRIES = [
         "meta-llama/Llama-3.3-70B-Instruct-Turbo",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    build_hf_repo_model_entry(
-        "meta-llama/Meta-Llama-Guard-3-8B",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "meta-llama/Llama-Guard-3-11B-Vision-Turbo",
-        CoreModelId.llama_guard_3_11b_vision.value,
-    ),
     ProviderModelEntry(
         provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
         model_type=ModelType.embedding,
@@ -78,4 +80,4 @@ MODEL_ENTRIES = [
             "together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
         ],
     ),
-]
+] + SAFETY_MODELS_ENTRIES

@@ -68,19 +68,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
     def __init__(self, config: TogetherImplConfig) -> None:
         ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
         self.config = config
-        self._client = None
-        self._openai_client = None
 
     async def initialize(self) -> None:
         pass
 
     async def shutdown(self) -> None:
-        if self._client:
-            # Together client has no close method, so just set to None
-            self._client = None
-        if self._openai_client:
-            await self._openai_client.close()
-            self._openai_client = None
+        pass
 
     async def completion(
         self,
@@ -108,29 +101,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
         return await self._nonstream_completion(request)
 
     def _get_client(self) -> AsyncTogether:
-        if not self._client:
-            together_api_key = None
-            config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
-            if config_api_key:
-                together_api_key = config_api_key
-            else:
-                provider_data = self.get_request_provider_data()
-                if provider_data is None or not provider_data.together_api_key:
-                    raise ValueError(
-                        'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
-                    )
-                together_api_key = provider_data.together_api_key
-            self._client = AsyncTogether(api_key=together_api_key)
-        return self._client
+        together_api_key = None
+        config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
+        if config_api_key:
+            together_api_key = config_api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.together_api_key:
+                raise ValueError(
+                    'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
+                )
+            together_api_key = provider_data.together_api_key
+        return AsyncTogether(api_key=together_api_key)
 
     def _get_openai_client(self) -> AsyncOpenAI:
-        if not self._openai_client:
-            together_client = self._get_client().client
-            self._openai_client = AsyncOpenAI(
-                base_url=together_client.base_url,
-                api_key=together_client.api_key,
-            )
-        return self._openai_client
+        together_client = self._get_client().client
+        return AsyncOpenAI(
+            base_url=together_client.base_url,
+            api_key=together_client.api_key,
+        )
 
     async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
         params = await self._get_params(request)

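When no key is configured, _get_client falls back to the together_api_key carried in the X-LlamaStack-Provider-Data request header; building the client per call is what makes that per-request key actually take effect. A hedged example of supplying that header from a caller; the server URL and route here are assumptions for illustration:

```python
# Illustrative only: shows the header shape _get_client expects when the
# server has no configured Together key. URL and route are assumed.
import json

import httpx

headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"together_api_key": "tk-..."}),
}
resp = httpx.post(
    "http://localhost:8321/v1/inference/chat-completion",  # route assumed
    headers=headers,
    json={
        "model_id": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        "messages": [{"role": "user", "content": "hi"}],
    },
)
print(resp.status_code)
```
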
@@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
 class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
     def __init__(self, config: SambaNovaSafetyConfig) -> None:
         self.config = config
+        self.environment_available_models = []
 
     async def initialize(self) -> None:
         pass
@@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
 
     async def register_shield(self, shield: Shield) -> None:
         list_models_url = self.config.url + "/models"
-        try:
-            response = requests.get(list_models_url)
-            response.raise_for_status()
-        except requests.exceptions.RequestException as e:
-            raise RuntimeError(f"Request to {list_models_url} failed") from e
-        available_models = [model.get("id") for model in response.json().get("data", {})]
+        if len(self.environment_available_models) == 0:
+            try:
+                response = requests.get(list_models_url)
+                response.raise_for_status()
+            except requests.exceptions.RequestException as e:
+                raise RuntimeError(f"Request to {list_models_url} failed") from e
+            self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
         if (
-            len(available_models) == 0
-            or "guard" not in shield.provider_resource_id.lower()
-            or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
+            "guard" not in shield.provider_resource_id.lower()
+            or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
         ):
-            raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")
+            logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
 
     async def run_shield(
         self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None

@@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex):
         self.consistency_level = consistency_level
         self.kvstore = kvstore
 
+    async def initialize(self):
+        # MilvusIndex does not require explicit initialization
+        # TODO: could move collection creation into initialization but it is not really necessary
+        pass
+
     async def delete(self):
         if await asyncio.to_thread(self.client.has_collection, self.collection_name):
             await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
@@ -174,7 +179,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             uri = os.path.expanduser(self.config.db_path)
             self.client = MilvusClient(uri=uri)
 
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
         self.client.close()
@@ -199,6 +205,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         if vector_db_id in self.cache:
            return self.cache[vector_db_id]
 
+        if self.vector_db_store is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
+
         vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
         if not vector_db:
             raise ValueError(f"Vector DB {vector_db_id} not found")
@@ -240,36 +249,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
 
         return await index.query_chunks(query, params)
 
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored = await self.kvstore.values_in_range(start_key, end_key)
-        return {json.loads(s)["id"]: json.loads(s) for s in stored}
-
     async def _save_openai_vector_store_file(
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:

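The four store-level metadata helpers deleted from the Milvus adapter here (_save/_update/_delete/_load of the OpenAI vector store records) move into the shared OpenAIVectorStoreMixin, so each backend only supplies a kvstore and initialize_openai_vector_stores() repopulates the cache on startup. A minimal sketch of that split, assuming the mixin calls hooks of this shape and using a dict-backed stand-in for the KV store:

```python
# Sketch of the mixin/backend split; names mirror the diff, the in-memory
# "kvstore" is a stand-in for the real KVStore implementations.
import json
from typing import Any

OPENAI_VECTOR_STORES_PREFIX = "openai_vector_stores:demo:v3::"


class InMemoryKV:
    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def values_in_range(self, start: str, end: str) -> list[str]:
        return [v for k, v in self._data.items() if start <= k <= end]


class StoreMetadataMixinSketch:
    def __init__(self) -> None:
        self.kvstore = InMemoryKV()
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        # Persist first, then update the in-memory cache.
        await self.kvstore.set(f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}", json.dumps(store_info))
        self.openai_vector_stores[store_id] = store_info

    async def initialize_openai_vector_stores(self) -> None:
        # Prefix scan over the kvstore repopulates the cache on startup.
        stored = await self.kvstore.values_in_range(
            OPENAI_VECTOR_STORES_PREFIX, f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
        )
        self.openai_vector_stores = {json.loads(s)["id"]: json.loads(s) for s in stored}
```
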
@@ -8,6 +8,10 @@ from typing import Any
 
 from pydantic import BaseModel, Field
 
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel):
     db: str | None = Field(default="postgres")
     user: str | None = Field(default="postgres")
     password: str | None = Field(default="mysecretpassword")
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
 
     @classmethod
     def sample_run_config(
         cls,
+        __distro_dir__: str,
         host: str = "${env.PGVECTOR_HOST:=localhost}",
         port: int = "${env.PGVECTOR_PORT:=5432}",
         db: str = "${env.PGVECTOR_DB}",
@@ -29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel):
         password: str = "${env.PGVECTOR_PASSWORD}",
         **kwargs: Any,
     ) -> dict[str, Any]:
-        return {"host": host, "port": port, "db": db, "user": user, "password": password}
+        return {
+            "host": host,
+            "port": port,
+            "db": db,
+            "user": user,
+            "password": password,
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="pgvector_registry.db",
+            ),
+        }

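sample_run_config now takes __distro_dir__ so distribution templates can place the SQLite registry database alongside the distro's other state. A hedged usage sketch; the import path and the exact layout of the generated kvstore entry are assumptions:

```python
# Assumed import path; the call shape follows the diff above.
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig

cfg = PGVectorVectorIOConfig.sample_run_config(__distro_dir__="~/.llama/distributions/starter")
# cfg still carries the ${env.PGVECTOR_*} placeholders for the connection,
# plus a "kvstore" entry pointing at a "pgvector_registry.db" SQLite file
# resolved under the distro directory.
print(cfg["host"], cfg["kvstore"])
```
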
@@ -13,24 +13,18 @@ from psycopg2 import sql
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter
 
+from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
-    SearchRankingOptions,
     VectorIO,
-    VectorStoreChunkingStrategy,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileObject,
-    VectorStoreFileStatus,
-    VectorStoreListFilesResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig
 
 log = logging.getLogger(__name__)
 
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
+
 
 def check_extension_version(cur):
     cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@@ -69,7 +70,7 @@ def load_models(cur, cls):
 
 
 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, vector_db: VectorDB, dimension: int, conn):
+    def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
         self.conn = conn
         with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             # Sanitize the table name by replacing hyphens with underscores
@@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex):
             # when created with patterns like "test-vector-db-{uuid4()}"
             sanitized_identifier = vector_db.identifier.replace("-", "_")
             self.table_name = f"vector_store_{sanitized_identifier}"
+            self.kvstore = kvstore
 
             cur.execute(
                 f"""
@@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex):
             cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
 
-class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
+class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
+    def __init__(
+        self,
+        config: PGVectorVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None = None,
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.conn = None
         self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_store: dict[str, dict[str, Any]] = {}
+        self.metadatadata_collection_name = "openai_vector_stores_metadata"
 
     async def initialize(self) -> None:
         log.info(f"Initializing PGVector memory adapter with config: {self.config}")
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        await self.initialize_openai_vector_stores()
+
         try:
             self.conn = psycopg2.connect(
                 host=self.config.host,
@@ -201,14 +216,31 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         log.info("Connection to PGVector database server closed")
 
     async def register_vector_db(self, vector_db: VectorDB) -> None:
+        # Persist vector DB metadata in the KV store
+        assert self.kvstore is not None
+        key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
+        await self.kvstore.set(key=key, value=vector_db.model_dump_json())
+
+        # Upsert model metadata in Postgres
         upsert_models(self.conn, [(vector_db.identifier, vector_db)])
 
-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
-        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+        # Create and cache the PGVector index table for the vector DB
+        index = VectorDBWithIndex(
+            vector_db,
+            index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
+            inference_api=self.inference_api,
+        )
+        self.cache[vector_db.identifier] = index
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:
-        await self.cache[vector_db_id].index.delete()
-        del self.cache[vector_db_id]
+        # Remove provider index and cache
+        if vector_db_id in self.cache:
+            await self.cache[vector_db_id].index.delete()
+            del self.cache[vector_db_id]
+
+        # Delete vector DB metadata from KV store
+        assert self.kvstore is not None
+        await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
 
     async def insert_chunks(
         self,
@@ -237,107 +269,20 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
         return self.cache[vector_db_id]
 
-    async def openai_create_vector_store(
-        self,
-        name: str,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> VectorStoreObject:
+    # OpenAI Vector Stores File operations are not supported in PGVector
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
 
-    async def openai_list_vector_stores(
-        self,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
 
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
 
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
 
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int | None = 10,
-        ranking_options: SearchRankingOptions | None = None,
-        rewrite_query: bool | None = False,
-        search_mode: str | None = "vector",
-    ) -> VectorStoreSearchResponsePage:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-        filter: VectorStoreFileStatus | None = None,
-    ) -> VectorStoreListFilesResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store_file_contents(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_update_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_delete_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")

@@ -6,15 +6,26 @@
 
 from typing import Any
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
 
 
 class WeaviateRequestProviderData(BaseModel):
     weaviate_api_key: str
     weaviate_cluster_url: str
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
 
 
 class WeaviateVectorIOConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
-        return {}
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="weaviate_registry.db",
+            ),
+        }

@@ -14,10 +14,13 @@ from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter
 
 from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.files.files import Files
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
 
 log = logging.getLogger(__name__)
 
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
+
 
 class WeaviateIndex(EmbeddingIndex):
-    def __init__(self, client: weaviate.Client, collection_name: str):
+    def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
         self.client = client
         self.collection_name = collection_name
+        self.kvstore = kvstore
 
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
@@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter(
     NeedsRequestProviderData,
     VectorDBsProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(
+        self,
+        config: WeaviateVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None,
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+        self.metadata_collection_name = "openai_vector_stores_metadata"
 
     def _get_client(self) -> weaviate.Client:
         provider_data = self.get_request_provider_data()
@@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter(
         return client
 
     async def initialize(self) -> None:
-        pass
+        """Set up KV store and load existing vector DBs and OpenAI vector stores."""
+        # Initialize KV store for metadata
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        # Load existing vector DB definitions
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored = await self.kvstore.values_in_range(start_key, end_key)
+        for raw in stored:
+            vector_db = VectorDB.model_validate_json(raw)
+            client = self._get_client()
+            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
+            self.cache[vector_db.identifier] = VectorDBWithIndex(
+                vector_db=vector_db,
+                index=idx,
+                inference_api=self.inference_api,
+            )
+
+        # Load OpenAI vector stores metadata into cache
+        await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
         for client in self.client_cache.values():
@@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter(
             raise ValueError(f"Vector DB {vector_db_id} not found")
 
         return await index.query_chunks(query, params)
+
+    # OpenAI Vector Stores File operations are not supported in Weaviate
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")

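The `\xff` end key in the new initialize() turns values_in_range into a prefix scan: every key that starts with VECTOR_DBS_PREFIX sorts between the prefix itself and prefix + `\xff`. A small sketch demonstrating the bound (this holds for the ASCII identifiers used in these key prefixes):

```python
# Why f"{prefix}\xff" works as an upper bound for a prefix scan over
# lexicographically ordered keys.
prefix = "vector_dbs:weaviate:v3::"
keys = [
    "vector_dbs:weaviate:v3::my-db",
    "vector_dbs:weaviate:v3::other-db",
    "vector_index:weaviate:v3::my-db",  # different prefix, must be excluded
]

start, end = prefix, prefix + "\xff"
matched = [k for k in sorted(keys) if start <= k <= end]
assert matched == ["vector_dbs:weaviate:v3::my-db", "vector_dbs:weaviate:v3::other-db"]
print(matched)
```
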