Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 04:00:42 +00:00)

Merge branch 'main' into change-default-embedding-model

This commit is contained in commit da35f2452e.
15 changed files with 473 additions and 231 deletions
@@ -66,11 +66,11 @@ runs:
       shell: bash
       run: |
         echo "Checking for recording changes"
-        git status --porcelain tests/integration/recordings/
+        git status --porcelain tests/integration/

-        if [[ -n $(git status --porcelain tests/integration/recordings/) ]]; then
+        if [[ -n $(git status --porcelain tests/integration/) ]]; then
           echo "New recordings detected, committing and pushing"
-          git add tests/integration/recordings/
+          git add tests/integration/

           git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})"
           git fetch origin ${{ github.ref_name }}
@@ -55,30 +55,18 @@ class VectorIORouter(VectorIO):
         logger.debug("VectorIORouter.shutdown")
         pass

-    async def _get_first_embedding_model(self) -> tuple[str, int] | None:
-        """Get the first available embedding model identifier."""
-        try:
-            # Get all models from the routing table
+    async def _get_embedding_model_dimension(self, embedding_model_id: str) -> int:
+        """Get the embedding dimension for a specific embedding model."""
         all_models = await self.routing_table.get_all_with_type("model")

-            # Filter for embedding models
-            embedding_models = [
-                model
-                for model in all_models
-                if hasattr(model, "model_type") and model.model_type == ModelType.embedding
-            ]
-
-            if embedding_models:
-                dimension = embedding_models[0].metadata.get("embedding_dimension", None)
+        for model in all_models:
+            if model.identifier == embedding_model_id and model.model_type == ModelType.embedding:
+                dimension = model.metadata.get("embedding_dimension")
                 if dimension is None:
-                    raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
-                return embedding_models[0].identifier, dimension
-            else:
-                logger.warning("No embedding models found in the routing table")
-                return None
-        except Exception as e:
-            logger.error(f"Error getting embedding models: {e}")
-            return None
+                    raise ValueError(f"Embedding model '{embedding_model_id}' has no embedding_dimension in metadata")
+                return int(dimension)
+
+        raise ValueError(f"Embedding model '{embedding_model_id}' not found or not an embedding model")

     async def register_vector_db(
         self,
@@ -129,20 +117,30 @@ class VectorIORouter(VectorIO):
         # Extract llama-stack-specific parameters from extra_body
         extra = params.model_extra or {}
         embedding_model = extra.get("embedding_model")
-        embedding_dimension = extra.get("embedding_dimension", 384)
+        embedding_dimension = extra.get("embedding_dimension")
         provider_id = extra.get("provider_id")

         logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")

-        # If no embedding model is provided, use the first available one
-        # TODO: this branch will soon be deleted so you _must_ provide the embedding_model when
-        # creating a vector store
+        # Require explicit embedding model specification
         if embedding_model is None:
-            embedding_model_info = await self._get_first_embedding_model()
-            if embedding_model_info is None:
-                raise ValueError("No embedding model provided and no embedding models available in the system")
-            embedding_model, embedding_dimension = embedding_model_info
-            logger.info(f"No embedding model specified, using first available: {embedding_model}")
+            raise ValueError("embedding_model is required in extra_body when creating a vector store")

+        if embedding_dimension is None:
+            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
+
+        # Auto-select provider if not specified
+        if provider_id is None:
+            num_providers = len(self.routing_table.impls_by_provider_id)
+            if num_providers == 0:
+                raise ValueError("No vector_io providers available")
+            if num_providers > 1:
+                available_providers = list(self.routing_table.impls_by_provider_id.keys())
+                raise ValueError(
+                    f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
+                    f"Available providers: {available_providers}"
+                )
+            provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
+
         vector_db_id = f"vs_{uuid.uuid4()}"
         registered_vector_db = await self.routing_table.register_vector_db(
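After this change, embedding_model must be supplied explicitly when creating a vector store; embedding_dimension is looked up from the model's metadata when omitted, and provider_id is auto-selected only when exactly one vector_io provider is registered. A minimal sketch of the new calling convention from a client (hedged: the base URL and model name below are illustrative placeholders, not values mandated by this diff):

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:8321")

    # embedding_model is now required in extra_body; embedding_dimension and
    # provider_id are optional and resolved server-side as described above.
    vs = client.vector_stores.create(
        name="my_store",
        extra_body={"embedding_model": "all-MiniLM-L6-v2"},
    )
    print(vs.id)  # e.g. "vs_<uuid4>"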
@@ -5,13 +5,11 @@
 # the root directory of this source tree.

+import ssl
-import time
 from abc import ABC, abstractmethod
-from asyncio import Lock
 from urllib.parse import parse_qs, urljoin, urlparse

 import httpx
-from jose import jwt
+import jwt
 from pydantic import BaseModel, Field

 from llama_stack.apis.common.errors import TokenValidationError
@@ -98,9 +96,7 @@ class OAuth2TokenAuthProvider(AuthProvider):

     def __init__(self, config: OAuth2TokenAuthConfig):
         self.config = config
-        self._jwks_at: float = 0.0
-        self._jwks: dict[str, str] = {}
-        self._jwks_lock = Lock()
+        self._jwks_client: jwt.PyJWKClient | None = None

     async def validate_token(self, token: str, scope: dict | None = None) -> User:
         if self.config.jwks:
@@ -109,23 +105,60 @@ class OAuth2TokenAuthProvider(AuthProvider):
             return await self.introspect_token(token, scope)
         raise ValueError("One of jwks or introspection must be configured")

+    def _get_jwks_client(self) -> jwt.PyJWKClient:
+        if self._jwks_client is None:
+            ssl_context = None
+            if not self.config.verify_tls:
+                # Disable SSL verification if verify_tls is False
+                ssl_context = ssl.create_default_context()
+                ssl_context.check_hostname = False
+                ssl_context.verify_mode = ssl.CERT_NONE
+            elif self.config.tls_cafile:
+                # Use custom CA file if provided
+                ssl_context = ssl.create_default_context(
+                    cafile=self.config.tls_cafile.as_posix(),
+                )
+            # If verify_tls is True and no tls_cafile, ssl_context remains None (use system defaults)
+
+            # Prepare headers for JWKS request - this is needed for Kubernetes to authenticate
+            # to the JWK endpoint, we must use the token in the config to authenticate
+            headers = {}
+            if self.config.jwks and self.config.jwks.token:
+                headers["Authorization"] = f"Bearer {self.config.jwks.token}"
+
+            self._jwks_client = jwt.PyJWKClient(
+                self.config.jwks.uri if self.config.jwks else None,
+                cache_keys=True,
+                max_cached_keys=10,
+                lifespan=self.config.jwks.key_recheck_period if self.config.jwks else None,
+                headers=headers,
+                ssl_context=ssl_context,
+            )
+        return self._jwks_client
+
     async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
         """Validate a token using the JWT token."""
-        await self._refresh_jwks()
-
         try:
-            header = jwt.get_unverified_header(token)
-            kid = header["kid"]
-            if kid not in self._jwks:
-                raise ValueError(f"Unknown key ID: {kid}")
-            key_data = self._jwks[kid]
-            algorithm = header.get("alg", "RS256")
+            jwks_client: jwt.PyJWKClient = self._get_jwks_client()
+            signing_key = jwks_client.get_signing_key_from_jwt(token)
+            algorithm = jwt.get_unverified_header(token)["alg"]
+            claims = jwt.decode(
+                token,
+                signing_key.key,
+                algorithms=[algorithm],
+                audience=self.config.audience,
+                issuer=self.config.issuer,
+                options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
+            )

-            # Decode and verify the JWT
-            claims = jwt.decode(
-                token,
-                key_data,
-                algorithms=[algorithm],
-                audience=self.config.audience,
-                issuer=self.config.issuer,
-                options={"verify_exp": True, "verify_aud": True, "verify_iss": True},
-            )
         except Exception as exc:
             raise ValueError("Invalid JWT token") from exc
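The validation path above now delegates JWKS fetching, caching, and key selection to PyJWT's PyJWKClient instead of the hand-rolled httpx cache that the next hunk deletes. A standalone sketch of the same flow (hedged: the JWKS URL and audience are placeholders; assumes PyJWT >= 2.10 with the crypto extra installed):

    import jwt

    # PyJWKClient caches fetched keys and picks the right one via the token's "kid" header.
    jwks_client = jwt.PyJWKClient("https://issuer.example.com/.well-known/jwks.json", cache_keys=True)

    def verify(token: str) -> dict:
        signing_key = jwks_client.get_signing_key_from_jwt(token)
        algorithm = jwt.get_unverified_header(token)["alg"]
        return jwt.decode(token, signing_key.key, algorithms=[algorithm], audience="my-audience")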
@@ -201,37 +234,6 @@ class OAuth2TokenAuthProvider(AuthProvider):
         else:
             return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"

-    async def _refresh_jwks(self) -> None:
-        """
-        Refresh the JWKS cache.
-
-        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
-        If the cache is expired, we refresh the JWKS from the JWKS URI.
-
-        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
-            * It doesn't have user authentication flows
-            * It doesn't have refresh tokens
-        """
-        async with self._jwks_lock:
-            if self.config.jwks is None:
-                raise ValueError("JWKS is not configured")
-            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
-                headers = {}
-                if self.config.jwks.token:
-                    headers["Authorization"] = f"Bearer {self.config.jwks.token}"
-                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
-                async with httpx.AsyncClient(verify=verify) as client:
-                    res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
-                    res.raise_for_status()
-                    jwks_data = res.json()["keys"]
-                updated = {}
-                for k in jwks_data:
-                    kid = k["kid"]
-                    # Store the entire key object as it may be needed for different algorithms
-                    updated[kid] = k
-                self._jwks = updated
-                self._jwks_at = time.time()
-

 class CustomAuthProvider(AuthProvider):
     """Custom authentication provider that uses an external endpoint."""
@@ -277,7 +277,7 @@ Available Models:
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.watsonx",
         config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
-        provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+        provider_data_validator="llama_stack.providers.remote.inference.watsonx.config.WatsonXProviderDataValidator",
         description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
     ),
     RemoteProviderSpec(
@@ -7,18 +7,18 @@
 import os
 from typing import Any

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, Field

 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack.schema_utils import json_schema_type


 class WatsonXProviderDataValidator(BaseModel):
-    model_config = ConfigDict(
-        from_attributes=True,
-        extra="forbid",
+    watsonx_project_id: str | None = Field(
+        default=None,
+        description="IBM WatsonX project ID",
     )
-    watsonx_api_key: str | None
+    watsonx_api_key: str | None = None


 @json_schema_type
@@ -4,42 +4,259 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import AsyncIterator
 from typing import Any

+import litellm
 import requests

 from llama_stack.apis.inference import ChatCompletionRequest
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
 from llama_stack.apis.models import Model
 from llama_stack.apis.models.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.telemetry.tracing import get_current_span

+logger = get_logger(name=__name__, category="providers::remote::watsonx")


 class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
     _model_cache: dict[str, Model] = {}

+    provider_data_api_key_field: str = "watsonx_api_key"
+
     def __init__(self, config: WatsonXConfig):
+        self.available_models = None
+        self.config = config
+        api_key = config.auth_credential.get_secret_value() if config.auth_credential else None
         LiteLLMOpenAIMixin.__init__(
             self,
             litellm_provider_name="watsonx",
-            api_key_from_config=config.auth_credential.get_secret_value() if config.auth_credential else None,
+            api_key_from_config=api_key,
             provider_data_api_key_field="watsonx_api_key",
             openai_compat_api_base=self.get_base_url(),
         )

+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Override parent method to add timeout and inject usage object when missing.
+        This works around a LiteLLM defect where usage block is sometimes dropped.
+        """
+
+        # Add usage tracking for streaming when telemetry is active
+        stream_options = params.stream_options
+        if params.stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        result = await litellm.acompletion(**request_params)
+
+        # If not streaming, check and inject usage if missing
+        if not params.stream:
+            # Use getattr to safely handle cases where usage attribute might not exist
+            if getattr(result, "usage", None) is None:
+                # Create usage object with zeros
+                usage_obj = OpenAIChatCompletionUsage(
+                    prompt_tokens=0,
+                    completion_tokens=0,
+                    total_tokens=0,
+                )
+                # Use model_copy to create a new response with the usage injected
+                result = result.model_copy(update={"usage": usage_obj})
+            return result
+
+        # For streaming, wrap the iterator to normalize chunks
+        return self._normalize_stream(result)
+
+    def _normalize_chunk(self, chunk: OpenAIChatCompletionChunk) -> OpenAIChatCompletionChunk:
+        """
+        Normalize a chunk to ensure it has all expected attributes.
+        This works around LiteLLM not always including all expected attributes.
+        """
+        # Ensure chunk has usage attribute with zeros if missing
+        if not hasattr(chunk, "usage") or chunk.usage is None:
+            usage_obj = OpenAIChatCompletionUsage(
+                prompt_tokens=0,
+                completion_tokens=0,
+                total_tokens=0,
+            )
+            chunk = chunk.model_copy(update={"usage": usage_obj})
+
+        # Ensure all delta objects in choices have expected attributes
+        if hasattr(chunk, "choices") and chunk.choices:
+            normalized_choices = []
+            for choice in chunk.choices:
+                if hasattr(choice, "delta") and choice.delta:
+                    delta = choice.delta
+                    # Build update dict for missing attributes
+                    delta_updates = {}
+                    if not hasattr(delta, "refusal"):
+                        delta_updates["refusal"] = None
+                    if not hasattr(delta, "reasoning_content"):
+                        delta_updates["reasoning_content"] = None
+
+                    # If we need to update delta, create a new choice with updated delta
+                    if delta_updates:
+                        new_delta = delta.model_copy(update=delta_updates)
+                        new_choice = choice.model_copy(update={"delta": new_delta})
+                        normalized_choices.append(new_choice)
+                    else:
+                        normalized_choices.append(choice)
+                else:
+                    normalized_choices.append(choice)
+
+            # If we modified any choices, create a new chunk with updated choices
+            if any(normalized_choices[i] is not chunk.choices[i] for i in range(len(chunk.choices))):
+                chunk = chunk.model_copy(update={"choices": normalized_choices})
+
+        return chunk
+
+    async def _normalize_stream(
+        self, stream: AsyncIterator[OpenAIChatCompletionChunk]
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """
+        Normalize all chunks in the stream to ensure they have expected attributes.
+        This works around LiteLLM sometimes not including expected attributes.
+        """
+        try:
+            async for chunk in stream:
+                # Normalize and yield each chunk immediately
+                yield self._normalize_chunk(chunk)
+        except Exception as e:
+            logger.error(f"Error normalizing stream: {e}", exc_info=True)
+            raise
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+
+        model_obj = await self.model_store.get_model(params.model)
+
+        request_params = await prepare_openai_completion_params(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            prompt=params.prompt,
+            best_of=params.best_of,
+            echo=params.echo,
+            frequency_penalty=params.frequency_penalty,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            presence_penalty=params.presence_penalty,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            top_p=params.top_p,
+            user=params.user,
+            suffix=params.suffix,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+        return await litellm.atext_completion(**request_params)
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        """
+        Override parent method to add watsonx-specific parameters.
+        """
+        model_obj = await self.model_store.get_model(params.model)
+
+        # Convert input to list if it's a string
+        input_list = [params.input] if isinstance(params.input, str) else params.input
+
+        # Call litellm embedding function with watsonx-specific parameters
+        response = litellm.embedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=params.dimensions,
+            # These are watsonx-specific parameters
+            timeout=self.config.timeout,
+            project_id=self.config.project_id,
+        )
+
+        # Convert response to OpenAI format
+        from llama_stack.apis.inference import OpenAIEmbeddingUsage
+        from llama_stack.providers.utils.inference.litellm_openai_mixin import b64_encode_openai_embeddings_response
+
+        data = b64_encode_openai_embeddings_response(response.data, params.encoding_format)
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
+        )
-        self.available_models = None
-        self.config = config

     def get_base_url(self) -> str:
         return self.config.url

     async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
         # Get base parameters from parent
         params = await super()._get_params(request)

         # Add watsonx.ai specific parameters
         params["project_id"] = self.config.project_id
         params["time_limit"] = self.config.timeout
         return params

     # Copied from OpenAIMixin
     async def check_model_availability(self, model: str) -> bool:
         """
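The repeated model_copy(update=...) calls in the adapter are the standard Pydantic v2 way to patch a model without mutating shared state. A minimal standalone illustration of the pattern (toy classes for illustration, not the real chunk types):

    from pydantic import BaseModel

    class Usage(BaseModel):
        prompt_tokens: int = 0
        completion_tokens: int = 0
        total_tokens: int = 0

    class Chunk(BaseModel):
        content: str
        usage: Usage | None = None

    chunk = Chunk(content="hello")
    # Same pattern as _normalize_chunk: build a new object with usage injected,
    # leaving the original untouched.
    patched = chunk.model_copy(update={"usage": Usage()})
    assert patched.usage is not None and chunk.usage is None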
@@ -353,14 +353,11 @@ class OpenAIVectorStoreMixin(ABC):
         provider_vector_db_id = extra.get("provider_vector_db_id")
         embedding_model = extra.get("embedding_model")
         embedding_dimension = extra.get("embedding_dimension", 768)
-        provider_id = extra.get("provider_id")
-
+        # use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
+        provider_id = extra.get("provider_id") or getattr(self, "__provider_id__", None)
         # Derive the canonical vector_db_id (allow override, else generate)
         vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")

-        if provider_id is None:
-            raise ValueError("Provider ID is required")
-
         if embedding_model is None:
             raise ValueError("Embedding model is required")

@@ -369,6 +366,9 @@ class OpenAIVectorStoreMixin(ABC):
             raise ValueError("Embedding dimension is required")

+        # Register the VectorDB backing this vector store
+        if provider_id is None:
+            raise ValueError("Provider ID is required but was not provided")
+
         vector_db = VectorDB(
             identifier=vector_db_id,
             embedding_dimension=embedding_dimension,
@@ -34,7 +34,7 @@ dependencies = [
     "openai>=1.107",  # for expires_after support
     "prompt-toolkit",
     "python-dotenv",
-    "python-jose[cryptography]",
+    "pyjwt[crypto]>=2.10.0",  # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
     "pydantic>=2.11.9",
     "rich",
     "starlette",
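The uv.lock hunks further down mirror this dependency swap. In a uv-managed checkout, something like the following would apply it and regenerate the lockfile (hedged: exact invocation depends on local setup):

    uv remove python-jose
    uv add "pyjwt[crypto]>=2.10.0"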
@@ -58,7 +58,6 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         # does not work with the specified model, gpt-5-mini. Please choose different model and try
         # again. You can learn more about which models can be used with each operation here:
         # https://go.microsoft.com/fwlink/?linkid=2197993.'}}"}
-        "remote::watsonx",  # return 404 when hitting the /openai/v1 endpoint
         "remote::llama-openai-compat",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")

@@ -68,6 +67,7 @@ def skip_if_doesnt_support_completions_logprobs(client_with_models, model_id):
     provider_type = provider_from_model(client_with_models, model_id).provider_type
     if provider_type in (
         "remote::ollama",  # logprobs is ignored
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions logprobs.")

@@ -110,6 +110,7 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
         # Error code 400 - {'message': '"n" > 1 is not currently supported', 'type': 'invalid_request_error', 'param': 'n', 'code': 'wrong_api_format'}
         "remote::cerebras",
         "remote::databricks",  # Bad request: parameter "n" must be equal to 1 for streaming mode
+        "remote::watsonx",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")

@@ -124,7 +125,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
         "remote::databricks",
         "remote::cerebras",
         "remote::runpod",
-        "remote::watsonx",  # watsonx returns 404 when hitting the /openai/v1 endpoint
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")

@@ -508,6 +508,12 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
     assert "hello world" in normalized_content


+def skip_if_doesnt_support_completions_stop_sequence(client_with_models, model_id):
+    provider_type = provider_from_model(client_with_models, model_id).provider_type
+    if provider_type in ("remote::watsonx",):  # openai.BadRequestError: Error code: 400
+        pytest.skip(f"Model {model_id} hosted by {provider_type} doesn't support /v1/completions stop sequence.")
+
+
 @pytest.mark.parametrize(
     "test_case",
     [

@@ -516,6 +522,7 @@ def test_openai_chat_completion_non_streaming_with_file(openai_client, client_wi
 )
 def test_openai_completion_stop_sequence(client_with_models, openai_client, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+    skip_if_doesnt_support_completions_stop_sequence(client_with_models, text_model_id)

     tc = TestCase(test_case)

@@ -50,11 +50,15 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):

 def skip_if_model_doesnt_support_variable_dimensions(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
-    if provider.provider_type in (
+    if (
+        provider.provider_type
+        in (
         "remote::together",  # returns 400
         "inline::sentence-transformers",
         # Error code: 400 - {'error_code': 'BAD_REQUEST', 'message': 'Bad request: json: unknown field "dimensions"\n'}
         "remote::databricks",
+        "remote::watsonx",  # openai.BadRequestError: Error code: 400 - {'detail': "litellm.UnsupportedParamsError: watsonx does not support parameters: {'dimensions': 384}
+    )
     ):
         pytest.skip(
             f"Model {model_id} hosted by {provider.provider_type} does not support variable output embedding dimensions."
@@ -146,8 +146,6 @@ def test_openai_create_vector_store(
         metadata={"purpose": "testing", "environment": "integration"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -175,8 +173,6 @@ def test_openai_list_vector_stores(
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     store2 = client.vector_stores.create(

@@ -184,8 +180,6 @@ def test_openai_list_vector_stores(
         metadata={"type": "test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -220,8 +214,6 @@ def test_openai_retrieve_vector_store(
         metadata={"purpose": "retrieval_test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -249,8 +241,6 @@ def test_openai_update_vector_store(
         metadata={"version": "1.0"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     time.sleep(1)

@@ -282,8 +272,6 @@ def test_openai_delete_vector_store(
         metadata={"purpose": "deletion_test"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -314,8 +302,6 @@ def test_openai_vector_store_search_empty(
         metadata={"purpose": "search_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -346,8 +332,6 @@ def test_openai_vector_store_with_chunks(
         metadata={"purpose": "chunks_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -412,8 +396,6 @@ def test_openai_vector_store_search_relevance(
         metadata={"purpose": "relevance_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -457,8 +439,6 @@ def test_openai_vector_store_search_with_ranking_options(
         metadata={"purpose": "ranking_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -512,8 +492,6 @@ def test_openai_vector_store_search_with_high_score_filter(
         metadata={"purpose": "high_score_filtering"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -573,8 +551,6 @@ def test_openai_vector_store_search_with_max_num_results(
         metadata={"purpose": "max_num_results_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -608,8 +584,6 @@ def test_openai_vector_store_attach_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -688,8 +662,6 @@ def test_openai_vector_store_attach_files_on_creation(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -735,8 +707,6 @@ def test_openai_vector_store_list_files(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -826,8 +796,6 @@ def test_openai_vector_store_retrieve_file_contents(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -851,8 +819,6 @@ def test_openai_vector_store_retrieve_file_contents(
         attributes=attributes,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -889,8 +855,6 @@ def test_openai_vector_store_delete_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -955,8 +919,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1007,8 +969,6 @@ def test_openai_vector_store_update_file(
         name="test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1078,8 +1038,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         name="test_store_with_files",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     assert vector_store.file_counts.completed == 0

@@ -1092,8 +1050,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         name="test_store_with_files",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1105,8 +1061,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         file_id=file_ids[0],
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     assert created_file.status == "completed"

@@ -1117,8 +1071,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(
         file_id=file_ids[1],
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )
     assert created_file_from_non_deleted_vector_store.status == "completed"

@@ -1139,8 +1091,6 @@ def test_openai_vector_store_search_modes(
         metadata={"purpose": "search_mode_testing"},
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1172,8 +1122,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
         name="batch_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1191,8 +1139,6 @@ def test_openai_vector_store_file_batch_create_and_retrieve(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1239,8 +1185,6 @@ def test_openai_vector_store_file_batch_list_files(
         name="batch_list_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1258,8 +1202,6 @@ def test_openai_vector_store_file_batch_list_files(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1336,8 +1278,6 @@ def test_openai_vector_store_file_batch_cancel(
         name="batch_cancel_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1355,8 +1295,6 @@ def test_openai_vector_store_file_batch_cancel(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1395,8 +1333,6 @@ def test_openai_vector_store_file_batch_retrieve_contents(
         name="batch_contents_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1419,8 +1355,6 @@ def test_openai_vector_store_file_batch_retrieve_contents(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1472,8 +1406,6 @@ def test_openai_vector_store_file_batch_error_handling(
         name="batch_error_test_store",
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -1485,8 +1417,6 @@ def test_openai_vector_store_file_batch_error_handling(
         file_ids=file_ids,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -52,8 +52,6 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -73,8 +71,6 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -110,8 +106,6 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -152,8 +146,6 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -202,8 +194,6 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
         name=vector_db_name,
         extra_body={
             "embedding_model": embedding_model_id,
-            "embedding_dimension": embedding_dimension,
-            "provider_id": "my_provider",
         },
     )

@@ -234,3 +224,35 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
     assert len(response.chunks) > 0
     assert response.chunks[0].metadata["document_id"] == "doc1"
     assert response.chunks[0].metadata["source"] == "precomputed"
+
+
+def test_auto_extract_embedding_dimension(client_with_empty_registry, embedding_model_id):
+    vs = client_with_empty_registry.vector_stores.create(
+        name="test_auto_extract", extra_body={"embedding_model": embedding_model_id}
+    )
+    assert vs.id is not None
+
+
+def test_provider_auto_selection_single_provider(client_with_empty_registry, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
+    if len(providers) != 1:
+        pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")
+
+    vs = client_with_empty_registry.vector_stores.create(
+        name="test_auto_provider", extra_body={"embedding_model": embedding_model_id}
+    )
+    assert vs.id is not None
+
+
+def test_provider_id_override(client_with_empty_registry, embedding_model_id):
+    providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
+    if len(providers) != 1:
+        pytest.skip(f"Test requires exactly one vector_io provider, found {len(providers)}")
+
+    provider_id = providers[0].provider_id
+
+    vs = client_with_empty_registry.vector_stores.create(
+        name="test_provider_override", extra_body={"embedding_model": embedding_model_id, "provider_id": provider_id}
+    )
+    assert vs.id is not None
+    assert vs.metadata.get("provider_id") == provider_id
tests/unit/core/routers/test_vector_io.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import AsyncMock, Mock
+
+import pytest
+
+from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
+from llama_stack.core.routers.vector_io import VectorIORouter
+
+
+async def test_single_provider_auto_selection():
+    # provider_id automatically selected during vector store create() when only one provider available
+    mock_routing_table = Mock()
+    mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"}
+    mock_routing_table.get_all_with_type = AsyncMock(
+        return_value=[
+            Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
+        ]
+    )
+    mock_routing_table.register_vector_db = AsyncMock(
+        return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
+    )
+    mock_routing_table.get_provider_impl = AsyncMock(
+        return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
+    )
+    router = VectorIORouter(mock_routing_table)
+    request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
+        {"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
+    )
+
+    result = await router.openai_create_vector_store(request)
+    assert result.id == "vs_123"
+
+
+async def test_create_vector_stores_multiple_providers_missing_provider_id_error():
+    # if multiple providers are available, vector store create will error without provider_id
+    mock_routing_table = Mock()
+    mock_routing_table.impls_by_provider_id = {
+        "inline::faiss": "mock_provider_1",
+        "inline::sqlite-vec": "mock_provider_2",
+    }
+    mock_routing_table.get_all_with_type = AsyncMock(
+        return_value=[
+            Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
+        ]
+    )
+    router = VectorIORouter(mock_routing_table)
+    request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
+        {"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
+    )
+
+    with pytest.raises(ValueError, match="Multiple vector_io providers available"):
+        await router.openai_create_vector_store(request)
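To exercise just these new router tests locally, the standard pytest invocation should suffice:

    pytest tests/unit/core/routers/test_vector_io.py -v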
@@ -5,7 +5,8 @@
 # the root directory of this source tree.

 import base64
-from unittest.mock import AsyncMock, patch
+import json
+from unittest.mock import AsyncMock, Mock, patch

 import pytest
 from fastapi import FastAPI

@@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs):

 @pytest.fixture
 def jwt_token_valid():
-    from jose import jwt
+    import jwt

     return jwt.encode(
         {

@@ -389,8 +390,30 @@ def jwt_token_valid():
     )


-@patch("httpx.AsyncClient.get", new=mock_jwks_response)
-def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
+@pytest.fixture
+def mock_jwks_urlopen():
+    """Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
+    with patch("urllib.request.urlopen") as mock_urlopen:
+        # Mock the JWKS response for PyJWKClient
+        mock_response = Mock()
+        mock_response.read.return_value = json.dumps(
+            {
+                "keys": [
+                    {
+                        "kid": "1234567890",
+                        "kty": "oct",
+                        "alg": "HS256",
+                        "use": "sig",
+                        "k": base64.b64encode(b"foobarbaz").decode(),
+                    }
+                ]
+            }
+        ).encode()
+        mock_urlopen.return_value.__enter__.return_value = mock_response
+        yield mock_urlopen
+
+
+def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
     response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
     assert response.status_code == 200
     assert response.json() == {"message": "Authentication successful"}

@@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
     assert response.status_code == 401


-@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
-def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
+def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
     response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
     assert response.status_code == 200
     assert response.json() == {"message": "Authentication successful"}
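The fixture works because an HS256 JWKS entry of type "oct" carries the shared secret base64-encoded in its "k" field, so PyJWKClient resolves the same "foobarbaz" secret the jwt_token_valid fixture signs with. A hedged standalone sketch of that symmetric round trip with PyJWT:

    import jwt

    # Sign with the shared secret and the kid the mocked JWKS serves, then verify.
    token = jwt.encode({"sub": "test"}, "foobarbaz", algorithm="HS256", headers={"kid": "1234567890"})
    claims = jwt.decode(token, "foobarbaz", algorithms=["HS256"])
    assert claims["sub"] == "test"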
uv.lock (generated, 49 lines changed)
@@ -874,18 +874,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" },
 ]

-[[package]]
-name = "ecdsa"
-version = "0.19.1"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "six" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" },
-]
-
 [[package]]
 name = "eval-type-backport"
 version = "0.2.2"

@@ -1787,8 +1775,8 @@ dependencies = [
     { name = "pillow" },
     { name = "prompt-toolkit" },
     { name = "pydantic" },
+    { name = "pyjwt", extra = ["crypto"] },
     { name = "python-dotenv" },
-    { name = "python-jose", extra = ["cryptography"] },
     { name = "python-multipart" },
     { name = "rich" },
     { name = "sqlalchemy", extra = ["asyncio"] },

@@ -1910,8 +1898,8 @@ requires-dist = [
     { name = "pillow" },
     { name = "prompt-toolkit" },
     { name = "pydantic", specifier = ">=2.11.9" },
+    { name = "pyjwt", extras = ["crypto"], specifier = ">=2.10.0" },
     { name = "python-dotenv" },
-    { name = "python-jose", extras = ["cryptography"] },
     { name = "python-multipart", specifier = ">=0.0.20" },
     { name = "rich" },
     { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.41" },

@@ -3558,6 +3546,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
 ]

+[[package]]
+name = "pyjwt"
+version = "2.10.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
+]
+
+[package.optional-dependencies]
+crypto = [
+    { name = "cryptography" },
+]
+
 [[package]]
 name = "pymilvus"
 version = "2.6.1"

@@ -3747,25 +3749,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/0c/fa/df59acedf7bbb937f69174d00f921a7b93aa5a5f5c17d05296c814fff6fc/python_engineio-4.12.2-py3-none-any.whl", hash = "sha256:8218ab66950e179dfec4b4bbb30aecf3f5d86f5e58e6fc1aa7fde2c698b2804f", size = 59536, upload-time = "2025-06-04T19:22:16.916Z" },
 ]

-[[package]]
-name = "python-jose"
-version = "3.5.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "ecdsa" },
-    { name = "pyasn1" },
-    { name = "rsa" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/c6/77/3a1c9039db7124eb039772b935f2244fbb73fc8ee65b9acf2375da1c07bf/python_jose-3.5.0.tar.gz", hash = "sha256:fb4eaa44dbeb1c26dcc69e4bd7ec54a1cb8dd64d3b4d81ef08d90ff453f2b01b", size = 92726, upload-time = "2025-05-28T17:31:54.288Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/d9/c3/0bd11992072e6a1c513b16500a5d07f91a24017c5909b02c72c62d7ad024/python_jose-3.5.0-py2.py3-none-any.whl", hash = "sha256:abd1202f23d34dfad2c3d28cb8617b90acf34132c7afd60abd0b0b7d3cb55771", size = 34624, upload-time = "2025-05-28T17:31:52.802Z" },
-]
-
-[package.optional-dependencies]
-cryptography = [
-    { name = "cryptography" },
-]
-
 [[package]]
 name = "python-multipart"
 version = "0.0.20"