mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 04:45:44 +00:00
Merge branch 'main' into content-extension
This commit is contained in:
commit
2fbddb4beb
30 changed files with 669 additions and 92 deletions
|
@ -15,6 +15,7 @@ from llama_stack.apis.safety import (
|
|||
RunShieldResponse,
|
||||
Safety,
|
||||
SafetyViolation,
|
||||
ShieldStore,
|
||||
ViolationLevel,
|
||||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
@ -32,6 +33,8 @@ PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
|||
|
||||
|
||||
class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
shield_store: ShieldStore
|
||||
|
||||
def __init__(self, config: PromptGuardConfig, _deps) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -53,7 +56,7 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
self,
|
||||
shield_id: str,
|
||||
messages: list[Message],
|
||||
params: dict[str, Any] = None,
|
||||
params: dict[str, Any],
|
||||
) -> RunShieldResponse:
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
|
@ -117,8 +120,10 @@ class PromptGuardShield:
|
|||
elif self.config.guard_type == PromptGuardType.jailbreak.value and score_malicious > self.threshold:
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
violation_type=f"prompt_injection:malicious={score_malicious}",
|
||||
violation_return_message="Sorry, I cannot do this.",
|
||||
user_message="Sorry, I cannot do this.",
|
||||
metadata={
|
||||
"violation_type": f"prompt_injection:malicious={score_malicious}",
|
||||
},
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
|
|
@ -174,7 +174,9 @@ class FaissIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in FAISS")
|
||||
raise NotImplementedError(
|
||||
"Keyword search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -185,7 +187,9 @@ class FaissIndex(EmbeddingIndex):
|
|||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in FAISS")
|
||||
raise NotImplementedError(
|
||||
"Hybrid search is not supported - underlying DB FAISS does not support this search mode"
|
||||
)
|
||||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
|
|
|
@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]:
|
|||
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="vertexai",
|
||||
pip_packages=["litellm", "google-cloud-aiplatform"],
|
||||
module="llama_stack.providers.remote.inference.vertexai",
|
||||
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
|
||||
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
|
||||
|
||||
• Enterprise-grade security: Uses Google Cloud's security controls and IAM
|
||||
• Better integration: Seamless integration with other Google Cloud services
|
||||
• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
|
||||
• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
|
||||
|
||||
Configuration:
|
||||
- Set VERTEX_AI_PROJECT environment variable (required)
|
||||
- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
|
||||
- Use Google Cloud Application Default Credentials or service account key
|
||||
|
||||
Authentication Setup:
|
||||
Option 1 (Recommended): gcloud auth application-default login
|
||||
Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path
|
||||
|
||||
Available Models:
|
||||
- vertex_ai/gemini-2.0-flash
|
||||
- vertex_ai/gemini-2.5-flash
|
||||
- vertex_ai/gemini-2.5-pro""",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -45,6 +45,18 @@ That means you'll get fast and efficient vector retrieval.
|
|||
- Lightweight and easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
- GPU support
|
||||
- **Vector search** - FAISS supports pure vector similarity search using embeddings
|
||||
|
||||
## Search Modes
|
||||
|
||||
**Supported:**
|
||||
- **Vector Search** (`mode="vector"`): Performs vector similarity search using embeddings
|
||||
|
||||
**Not Supported:**
|
||||
- **Keyword Search** (`mode="keyword"`): Not supported by FAISS
|
||||
- **Hybrid Search** (`mode="hybrid"`): Not supported by FAISS
|
||||
|
||||
> **Note**: FAISS is designed as a pure vector similarity search library. See the [FAISS GitHub repository](https://github.com/facebookresearch/faiss) for more details about FAISS's core functionality.
|
||||
|
||||
## Usage
|
||||
|
||||
|
@ -535,6 +547,7 @@ That means you're not limited to storing vectors in memory or in a separate serv
|
|||
|
||||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
- Supports all search modes: vector, keyword, and hybrid search (both inline and remote configurations)
|
||||
|
||||
## Usage
|
||||
|
||||
|
@ -625,6 +638,92 @@ vector_io:
|
|||
- **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
|
||||
- **`client_key_path`**: Path to the **client private key** file (required for mTLS).
|
||||
|
||||
## Search Modes
|
||||
|
||||
Milvus supports three different search modes for both inline and remote configurations:
|
||||
|
||||
### Vector Search
|
||||
Vector search uses semantic similarity to find the most relevant chunks based on embedding vectors. This is the default search mode and works well for finding conceptually similar content.
|
||||
|
||||
```python
|
||||
# Vector search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="What is machine learning?",
|
||||
search_mode="vector",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Keyword Search
|
||||
Keyword search uses traditional text-based matching to find chunks containing specific terms or phrases. This is useful when you need exact term matches.
|
||||
|
||||
```python
|
||||
# Keyword search example
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="Python programming language",
|
||||
search_mode="keyword",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
### Hybrid Search
|
||||
Hybrid search combines both vector and keyword search methods to provide more comprehensive results. It leverages the strengths of both semantic similarity and exact term matching.
|
||||
|
||||
#### Basic Hybrid Search
|
||||
```python
|
||||
# Basic hybrid search example (uses RRF ranker with default impact_factor=60.0)
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
)
|
||||
```
|
||||
|
||||
**Note**: The default `impact_factor` value of 60.0 was empirically determined to be optimal in the original RRF research paper: ["Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) (Cormack et al., 2009).
|
||||
|
||||
#### Hybrid Search with RRF (Reciprocal Rank Fusion) Ranker
|
||||
RRF combines the rankings from vector and keyword search by summing reciprocal ranks: each result contributes `1 / (impact_factor + rank)` per ranking. The impact factor is a smoothing constant: lower values give more weight to top-ranked results, while higher values flatten the differences between ranks.
|
||||
|
||||
```python
|
||||
# Hybrid search with custom RRF parameters
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
ranking_options={
|
||||
"ranker": {
|
||||
"type": "rrf",
|
||||
"impact_factor": 100.0, # Higher values reduce the advantage of top-ranked results (flatter fusion)
|
||||
}
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
#### Hybrid Search with Weighted Ranker
|
||||
Weighted ranker linearly combines normalized scores from vector and keyword search. The alpha parameter controls the balance between the two search methods.
|
||||
|
||||
```python
|
||||
# Hybrid search with weighted ranker
|
||||
search_response = client.vector_stores.search(
|
||||
vector_store_id=vector_store.id,
|
||||
query="neural networks in Python",
|
||||
search_mode="hybrid",
|
||||
max_num_results=5,
|
||||
ranking_options={
|
||||
"ranker": {
|
||||
"type": "weighted",
|
||||
"alpha": 0.7, # 70% vector search, 30% keyword search
|
||||
}
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
For detailed documentation on RRF and Weighted rankers, please refer to the [Milvus Reranking Guide](https://milvus.io/docs/reranking.md).
|
||||
|
||||
## Documentation
|
||||
See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
|
||||
|
||||
|
|
15
llama_stack/providers/remote/inference/vertexai/__init__.py
Normal file
15
llama_stack/providers/remote/inference/vertexai/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import VertexAIConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: VertexAIConfig, _deps):
    """Build, initialize, and return the Vertex AI inference adapter.

    Args:
        config: Vertex AI provider configuration handed to the adapter.
        _deps: Accepted but not used by this function.

    Returns:
        The adapter instance after its ``initialize()`` coroutine has completed.
    """
    # The adapter module is imported at call time rather than at package import;
    # presumably to keep importing this package lightweight (confirm with owners).
    from .vertexai import VertexAIInferenceAdapter

    adapter = VertexAIInferenceAdapter(config)
    await adapter.initialize()
    return adapter
|
45
llama_stack/providers/remote/inference/vertexai/config.py
Normal file
45
llama_stack/providers/remote/inference/vertexai/config.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class VertexAIProviderDataValidator(BaseModel):
    # Optional Vertex AI settings carried as provider data.
    # Both fields default to None when they are not supplied.
    vertex_project: str | None = Field(default=None, description="Google Cloud project ID for Vertex AI")
    vertex_location: str | None = Field(
        default=None, description="Google Cloud location for Vertex AI (e.g., us-central1)"
    )
|
||||
|
||||
|
||||
@json_schema_type
class VertexAIConfig(BaseModel):
    # Provider configuration for Vertex AI.
    # `project` declares no default; `location` defaults to "us-central1".
    project: str = Field(description="Google Cloud project ID for Vertex AI")
    location: str = Field(default="us-central1", description="Google Cloud location for Vertex AI")

    @classmethod
    def sample_run_config(
        cls,
        project: str = "${env.VERTEX_AI_PROJECT:=}",
        location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
        **kwargs,
    ) -> dict[str, Any]:
        """Return a sample run configuration holding only project and location.

        Args:
            project: Value for the "project" key; defaults to an environment
                placeholder string.
            location: Value for the "location" key; defaults to an environment
                placeholder string.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with exactly the keys "project" and "location".
        """
        return dict(project=project, location=location)
|
20
llama_stack/providers/remote/inference/vertexai/models.py
Normal file
20
llama_stack/providers/remote/inference/vertexai/models.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
|
||||
# Vertex AI model IDs with vertex_ai/ prefix as required by litellm
LLM_MODEL_IDS = [
    "vertex_ai/gemini-2.0-flash",
    "vertex_ai/gemini-2.5-flash",
    "vertex_ai/gemini-2.5-pro",
]

# No safety models are registered for this provider; kept as an empty list.
SAFETY_MODELS_ENTRIES: list[ProviderModelEntry] = []

# One entry per LLM model ID, followed by any safety model entries.
MODEL_ENTRIES = [
    *(ProviderModelEntry(provider_model_id=model_id) for model_id in LLM_MODEL_IDS),
    *SAFETY_MODELS_ENTRIES,
]
|
52
llama_stack/providers/remote/inference/vertexai/vertexai.py
Normal file
52
llama_stack/providers/remote/inference/vertexai/vertexai.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.inference import ChatCompletionRequest
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
|
||||
from .config import VertexAIConfig
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
    """Inference adapter that sends requests to Vertex AI through the LiteLLM mixin.

    Vertex AI authenticates with Application Default Credentials (ADC), so no
    API key is configured here and none is forwarded in request parameters.
    """

    def __init__(self, config: VertexAIConfig) -> None:
        """Register the Vertex AI model entries and keep the provider config.

        Args:
            config: Provider configuration supplying the default project and
                location used when a request does not override them.
        """
        LiteLLMOpenAIMixin.__init__(
            self,
            MODEL_ENTRIES,
            litellm_provider_name="vertex_ai",
            api_key_from_config=None,  # Vertex AI uses ADC, not API keys
            provider_data_api_key_field="vertex_project",  # Use project for validation
        )
        self.config = config

    def get_api_key(self) -> str:
        """Return an empty string so litellm handles authentication via ADC."""
        # Vertex AI doesn't use API keys, it uses Application Default Credentials
        return ""

    async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
        """Build the request parameters, adding the Vertex AI project and location.

        Each of ``vertex_project`` and ``vertex_location`` is taken from the
        request's provider data when that field is set there, and otherwise
        from the adapter's config. The fallback is applied per field, so
        provider data that sets only one of the two (or neither) still yields
        both parameters.

        Args:
            request: The chat completion request passed to the parent
                implementation.

        Returns:
            The parent's parameter dict with ``vertex_project`` and
            ``vertex_location`` set and any ``api_key`` entry removed.
        """
        # Get base parameters from parent
        params = await super()._get_params(request)

        # Falsy provider data is treated the same as "no provider data".
        request_data = self.get_request_provider_data() or None

        # Per-field fallback: a missing or empty override must not drop the
        # configured value.
        params["vertex_project"] = getattr(request_data, "vertex_project", None) or self.config.project
        params["vertex_location"] = getattr(request_data, "vertex_location", None) or self.config.location

        # Remove api_key since Vertex AI uses ADC
        params.pop("api_key", None)

        return params
|
Loading…
Add table
Add a link
Reference in a new issue