From 0b5a794c27949eb573b6d43108aa938805d232dd Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 8 Aug 2025 13:47:36 -0700 Subject: [PATCH 01/17] fix: telemetry logger spams when queue is full (#3070) # What does this PR do? ## Test Plan Ran a stress test on chat completion endpoint locally: For 10 concurrent users over 3 minutes: Before: image After: image (Will send scripts in a future PR.) --- .../providers/utils/telemetry/tracing.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 75b29cdce..7080e774a 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -9,7 +9,9 @@ import contextvars import logging import queue import random +import sys import threading +import time from collections.abc import Callable from datetime import UTC, datetime from functools import wraps @@ -30,6 +32,16 @@ from llama_stack.providers.utils.telemetry.trace_protocol import serialize_value logger = get_logger(__name__, category="core") +# Fallback logger that does NOT propagate to TelemetryHandler to avoid recursion +_fallback_logger = logging.getLogger("llama_stack.telemetry.background") +if not _fallback_logger.handlers: + _fallback_logger.propagate = False + _fallback_logger.setLevel(logging.ERROR) + _fallback_handler = logging.StreamHandler(sys.stderr) + _fallback_handler.setLevel(logging.ERROR) + _fallback_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")) + _fallback_logger.addHandler(_fallback_handler) + INVALID_SPAN_ID = 0x0000000000000000 INVALID_TRACE_ID = 0x00000000000000000000000000000000 @@ -79,19 +91,32 @@ def generate_trace_id() -> str: CURRENT_TRACE_CONTEXT = contextvars.ContextVar("trace_context", default=None) BACKGROUND_LOGGER = None +LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS = 60.0 + class BackgroundLogger: def __init__(self, api: Telemetry, capacity: int = 100000): self.api = api - self.log_queue = queue.Queue(maxsize=capacity) + self.log_queue: queue.Queue[Any] = queue.Queue(maxsize=capacity) self.worker_thread = threading.Thread(target=self._process_logs, daemon=True) self.worker_thread.start() + self._last_queue_full_log_time: float = 0.0 + self._dropped_since_last_notice: int = 0 def log_event(self, event): try: self.log_queue.put_nowait(event) except queue.Full: - logger.error("Log queue is full, dropping event") + # Aggregate drops and emit at most once per interval via fallback logger + self._dropped_since_last_notice += 1 + current_time = time.time() + if current_time - self._last_queue_full_log_time >= LOG_QUEUE_FULL_LOG_INTERVAL_SECONDS: + _fallback_logger.error( + "Log queue is full; dropped %d events since last notice", + self._dropped_since_last_notice, + ) + self._last_queue_full_log_time = current_time + self._dropped_since_last_notice = 0 def _process_logs(self): while True: From 1677d6bffdf0a002abe3b827c460df4930ee83c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Eli=C3=A1=C5=A1?= Date: Fri, 8 Aug 2025 22:48:15 +0200 Subject: [PATCH 02/17] feat: Flash-Lite 2.0 and 2.5 models added to Gemini inference provider (#3058) PR adds Flash-Lite 2.0 and 2.5 models to the Gemini inference provider Closes #3046 ## Test Plan I was not able to locate any existing test for this provider, so I performed manual testing. But the change is really trivial and straightforward. 
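For reference, a minimal sketch of such a manual check against a locally running stack. This is not part of the patch: it assumes the `llama_stack_client` package, a server on the default port with the Gemini provider configured, and that the new entries surface under a `gemini/` prefix — the exact identifiers may differ.

```python
# Hypothetical manual check (not part of this patch). Assumes a Llama Stack
# server running locally with the Gemini provider configured via GEMINI_API_KEY;
# the model identifier below is an assumption and may differ in your deployment.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Confirm the new Flash-Lite entries are registered
flash_lite = [m.identifier for m in client.models.list() if "flash-lite" in m.identifier]
print(flash_lite)

# Smoke-test one of the new models
response = client.inference.chat_completion(
    model_id="gemini/gemini-2.5-flash-lite",  # assumed identifier
    messages=[{"role": "user", "content": "Reply with the single word: ok"}],
)
print(response.completion_message.content)
```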
--- llama_stack/providers/remote/inference/gemini/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llama_stack/providers/remote/inference/gemini/models.py b/llama_stack/providers/remote/inference/gemini/models.py index 6fda35e0f..bd696b0ac 100644 --- a/llama_stack/providers/remote/inference/gemini/models.py +++ b/llama_stack/providers/remote/inference/gemini/models.py @@ -13,7 +13,9 @@ LLM_MODEL_IDS = [ "gemini-1.5-flash", "gemini-1.5-pro", "gemini-2.0-flash", + "gemini-2.0-flash-lite", "gemini-2.5-flash", + "gemini-2.5-flash-lite", "gemini-2.5-pro", ] From ce72a2852516bbd702d202b2f4426478643faea0 Mon Sep 17 00:00:00 2001 From: Varsha Date: Sun, 10 Aug 2025 15:48:36 -0700 Subject: [PATCH 03/17] docs: Update doc on search modes for Milvus (#3078) # What does this PR do? Update Milvus doc on using search modes. ## Test Plan Signed-off-by: Varsha Prasad Narsing --- .../providers/vector_io/remote_milvus.md | 87 +++++++++++++++++++ llama_stack/providers/registry/vector_io.py | 87 +++++++++++++++++++ 2 files changed, 174 insertions(+) diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index 3646f4acc..2af64b8bb 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -11,6 +11,7 @@ That means you're not limited to storing vectors in memory or in a separate serv - Easy to use - Fully integrated with Llama Stack +- Supports all search modes: vector, keyword, and hybrid search (both inline and remote configurations) ## Usage @@ -101,6 +102,92 @@ vector_io: - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS). - **`client_key_path`**: Path to the **client private key** file (required for mTLS). +## Search Modes + +Milvus supports three different search modes for both inline and remote configurations: + +### Vector Search +Vector search uses semantic similarity to find the most relevant chunks based on embedding vectors. This is the default search mode and works well for finding conceptually similar content. + +```python +# Vector search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is machine learning?", + search_mode="vector", + max_num_results=5, +) +``` + +### Keyword Search +Keyword search uses traditional text-based matching to find chunks containing specific terms or phrases. This is useful when you need exact term matches. + +```python +# Keyword search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="Python programming language", + search_mode="keyword", + max_num_results=5, +) +``` + +### Hybrid Search +Hybrid search combines both vector and keyword search methods to provide more comprehensive results. It leverages the strengths of both semantic similarity and exact term matching. + +#### Basic Hybrid Search +```python +# Basic hybrid search example (uses RRF ranker with default impact_factor=60.0) +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, +) +``` + +**Note**: The default `impact_factor` value of 60.0 was empirically determined to be optimal in the original RRF research paper: ["Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) (Cormack et al., 2009). 
+ +#### Hybrid Search with RRF (Reciprocal Rank Fusion) Ranker +RRF combines rankings from vector and keyword search by using reciprocal ranks. The impact factor controls how much weight is given to higher-ranked results. + +```python +# Hybrid search with custom RRF parameters +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "rrf", + "impact_factor": 100.0, # Higher values give more weight to top-ranked results + } + }, +) +``` + +#### Hybrid Search with Weighted Ranker +Weighted ranker linearly combines normalized scores from vector and keyword search. The alpha parameter controls the balance between the two search methods. + +```python +# Hybrid search with weighted ranker +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "weighted", + "alpha": 0.7, # 70% vector search, 30% keyword search + } + }, +) +``` + +For detailed documentation on RRF and Weighted rankers, please refer to the [Milvus Reranking Guide](https://milvus.io/docs/reranking.md). + ## Documentation See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 846f7b88e..b4f3ab6ac 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -535,6 +535,7 @@ That means you're not limited to storing vectors in memory or in a separate serv - Easy to use - Fully integrated with Llama Stack +- Supports all search modes: vector, keyword, and hybrid search (both inline and remote configurations) ## Usage @@ -625,6 +626,92 @@ vector_io: - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS). - **`client_key_path`**: Path to the **client private key** file (required for mTLS). +## Search Modes + +Milvus supports three different search modes for both inline and remote configurations: + +### Vector Search +Vector search uses semantic similarity to find the most relevant chunks based on embedding vectors. This is the default search mode and works well for finding conceptually similar content. + +```python +# Vector search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="What is machine learning?", + search_mode="vector", + max_num_results=5, +) +``` + +### Keyword Search +Keyword search uses traditional text-based matching to find chunks containing specific terms or phrases. This is useful when you need exact term matches. + +```python +# Keyword search example +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="Python programming language", + search_mode="keyword", + max_num_results=5, +) +``` + +### Hybrid Search +Hybrid search combines both vector and keyword search methods to provide more comprehensive results. It leverages the strengths of both semantic similarity and exact term matching. 
+ +#### Basic Hybrid Search +```python +# Basic hybrid search example (uses RRF ranker with default impact_factor=60.0) +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, +) +``` + +**Note**: The default `impact_factor` value of 60.0 was empirically determined to be optimal in the original RRF research paper: ["Reciprocal Rank Fusion outperforms Condorcet and individual Rank Learning Methods"](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) (Cormack et al., 2009). + +#### Hybrid Search with RRF (Reciprocal Rank Fusion) Ranker +RRF combines rankings from vector and keyword search by using reciprocal ranks. The impact factor controls how much weight is given to higher-ranked results. + +```python +# Hybrid search with custom RRF parameters +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "rrf", + "impact_factor": 100.0, # Higher values give more weight to top-ranked results + } + }, +) +``` + +#### Hybrid Search with Weighted Ranker +Weighted ranker linearly combines normalized scores from vector and keyword search. The alpha parameter controls the balance between the two search methods. + +```python +# Hybrid search with weighted ranker +search_response = client.vector_stores.search( + vector_store_id=vector_store.id, + query="neural networks in Python", + search_mode="hybrid", + max_num_results=5, + ranking_options={ + "ranker": { + "type": "weighted", + "alpha": 0.7, # 70% vector search, 30% keyword search + } + }, +) +``` + +For detailed documentation on RRF and Weighted rankers, please refer to the [Milvus Reranking Guide](https://milvus.io/docs/reranking.md). + ## Documentation See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. From 69dc789e15421e717d447042593cde08c61ad9e2 Mon Sep 17 00:00:00 2001 From: Varsha Date: Sun, 10 Aug 2025 16:34:34 -0700 Subject: [PATCH 04/17] docs: Add unsupported search mode info about FAISS (#3089) --- docs/source/providers/vector_io/inline_faiss.md | 12 ++++++++++++ .../providers/inline/vector_io/faiss/faiss.py | 8 ++++++-- llama_stack/providers/registry/vector_io.py | 12 ++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/docs/source/providers/vector_io/inline_faiss.md b/docs/source/providers/vector_io/inline_faiss.md index bcff66f3f..cfa18a839 100644 --- a/docs/source/providers/vector_io/inline_faiss.md +++ b/docs/source/providers/vector_io/inline_faiss.md @@ -12,6 +12,18 @@ That means you'll get fast and efficient vector retrieval. - Lightweight and easy to use - Fully integrated with Llama Stack - GPU support +- **Vector search** - FAISS supports pure vector similarity search using embeddings + +## Search Modes + +**Supported:** +- **Vector Search** (`mode="vector"`): Performs vector similarity search using embeddings + +**Not Supported:** +- **Keyword Search** (`mode="keyword"`): Not supported by FAISS +- **Hybrid Search** (`mode="hybrid"`): Not supported by FAISS + +> **Note**: FAISS is designed as a pure vector similarity search library. See the [FAISS GitHub repository](https://github.com/facebookresearch/faiss) for more details about FAISS's core functionality. 
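As an illustration (an editorial sketch, not part of the patch), the following shows how the supported and unsupported modes behave against a FAISS-backed vector store; it assumes the `llama_stack_client` package, a running stack, and a placeholder vector store ID, and uses the same `vector_stores.search` call shown in the Milvus documentation above.

```python
# Hypothetical sketch (not part of this patch): FAISS serves vector search,
# while keyword and hybrid modes are rejected by the provider.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Vector search is supported
results = client.vector_stores.search(
    vector_store_id="vs_faiss_example",  # placeholder ID
    query="What is machine learning?",
    search_mode="vector",
    max_num_results=5,
)

# Keyword (and hybrid) search is not supported by FAISS and surfaces an error
try:
    client.vector_stores.search(
        vector_store_id="vs_faiss_example",
        query="machine learning",
        search_mode="keyword",
    )
except Exception as err:
    print(f"keyword search rejected: {err}")
```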
## Usage diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 7a5373726..5a063592c 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -174,7 +174,9 @@ class FaissIndex(EmbeddingIndex): k: int, score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in FAISS") + raise NotImplementedError( + "Keyword search is not supported - underlying DB FAISS does not support this search mode" + ) async def query_hybrid( self, @@ -185,7 +187,9 @@ class FaissIndex(EmbeddingIndex): reranker_type: str, reranker_params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - raise NotImplementedError("Hybrid search is not supported in FAISS") + raise NotImplementedError( + "Hybrid search is not supported - underlying DB FAISS does not support this search mode" + ) class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index b4f3ab6ac..ed170b508 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -45,6 +45,18 @@ That means you'll get fast and efficient vector retrieval. - Lightweight and easy to use - Fully integrated with Llama Stack - GPU support +- **Vector search** - FAISS supports pure vector similarity search using embeddings + +## Search Modes + +**Supported:** +- **Vector Search** (`mode="vector"`): Performs vector similarity search using embeddings + +**Not Supported:** +- **Keyword Search** (`mode="keyword"`): Not supported by FAISS +- **Hybrid Search** (`mode="hybrid"`): Not supported by FAISS + +> **Note**: FAISS is designed as a pure vector similarity search library. See the [FAISS GitHub repository](https://github.com/facebookresearch/faiss) for more details about FAISS's core functionality. ## Usage From 78a59a4dbeaa50cc85d4f08f6c72b1bb51e7f721 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Sun, 10 Aug 2025 19:11:14 -0600 Subject: [PATCH 05/17] chore: Adding GitHub Stars, trends, and contributor shout out to README (#3079) # What does this PR do? Updates READMe to add 1. GitHub badge highlighting Llama Stack as #1 Repo of the Day 2. GitHub Star History (cumulative stars chart) 3. Contributor shout out ## Test Plan Signed-off-by: Francisco Javier Arceo --- README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/README.md b/README.md index 03aa3dd50..8db4580a2 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # Llama Stack +meta-llama%2Fllama-stack | Trendshift + +----- [![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE) @@ -9,6 +12,7 @@ [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack) + ### ✨🎉 Llama 4 Support 🎉✨ We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta. 
@@ -179,3 +183,17 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest Check out our client SDKs for connecting to a Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications. You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo. + + +## 🌟 GitHub Star History +## Star History + +[![Star History Chart](https://api.star-history.com/svg?repos=meta-llama/llama-stack&type=Date)](https://www.star-history.com/#meta-llama/llama-stack&Date) + +## ✨ Contributors + +Thanks to all of our amazing contributors! + + + + \ No newline at end of file From a4bad6c0b44dabad5fe183266960eeb52b2b27d6 Mon Sep 17 00:00:00 2001 From: Eran Cohen Date: Mon, 11 Aug 2025 15:22:04 +0300 Subject: [PATCH 06/17] feat: Add Google Vertex AI inference provider support (#2841) # What does this PR do? - Add new Vertex AI remote inference provider with litellm integration - Support for Gemini models through Google Cloud Vertex AI platform - Uses Google Cloud Application Default Credentials (ADC) for authentication - Added VertexAI models: gemini-2.5-flash, gemini-2.5-pro, gemini-2.0-flash. - Updated provider registry to include vertexai provider - Updated starter template to support Vertex AI configuration - Added comprehensive documentation and sample configuration relates to https://github.com/meta-llama/llama-stack/issues/2747 ## Test Plan Signed-off-by: Eran Cohen Co-authored-by: Francisco Arceo --- docs/source/providers/inference/index.md | 1 + .../providers/inference/remote_vertexai.md | 40 ++++++++++++++ llama_stack/distributions/ci-tests/build.yaml | 1 + llama_stack/distributions/ci-tests/run.yaml | 5 ++ llama_stack/distributions/starter/build.yaml | 1 + llama_stack/distributions/starter/run.yaml | 5 ++ llama_stack/distributions/starter/starter.py | 10 ++++ llama_stack/providers/registry/inference.py | 30 +++++++++++ .../remote/inference/vertexai/__init__.py | 15 ++++++ .../remote/inference/vertexai/config.py | 45 ++++++++++++++++ .../remote/inference/vertexai/models.py | 20 +++++++ .../remote/inference/vertexai/vertexai.py | 52 +++++++++++++++++++ .../inference/test_openai_completion.py | 1 + .../inference/test_text_inference.py | 1 + 14 files changed, 227 insertions(+) create mode 100644 docs/source/providers/inference/remote_vertexai.md create mode 100644 llama_stack/providers/remote/inference/vertexai/__init__.py create mode 100644 llama_stack/providers/remote/inference/vertexai/config.py create mode 100644 llama_stack/providers/remote/inference/vertexai/models.py create mode 100644 llama_stack/providers/remote/inference/vertexai/vertexai.py diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 1c7bc86b9..38781e5eb 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -29,6 +29,7 @@ remote_runpod remote_sambanova remote_tgi remote_together +remote_vertexai remote_vllm remote_watsonx ``` diff --git a/docs/source/providers/inference/remote_vertexai.md b/docs/source/providers/inference/remote_vertexai.md new file mode 100644 
index 000000000..962bbd76f --- /dev/null +++ b/docs/source/providers/inference/remote_vertexai.md @@ -0,0 +1,40 @@ +# remote::vertexai + +## Description + +Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + +• Enterprise-grade security: Uses Google Cloud's security controls and IAM +• Better integration: Seamless integration with other Google Cloud services +• Advanced features: Access to additional Vertex AI features like model tuning and monitoring +• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys + +Configuration: +- Set VERTEX_AI_PROJECT environment variable (required) +- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1) +- Use Google Cloud Application Default Credentials or service account key + +Authentication Setup: +Option 1 (Recommended): gcloud auth application-default login +Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path + +Available Models: +- vertex_ai/gemini-2.0-flash +- vertex_ai/gemini-2.5-flash +- vertex_ai/gemini-2.5-pro + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `project` | `` | No | | Google Cloud project ID for Vertex AI | +| `location` | `` | No | us-central1 | Google Cloud location for Vertex AI | + +## Sample Configuration + +```yaml +project: ${env.VERTEX_AI_PROJECT:=} +location: ${env.VERTEX_AI_LOCATION:=us-central1} + +``` + diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index 2f9ae8682..e6e699b62 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -14,6 +14,7 @@ distribution_spec: - provider_type: remote::openai - provider_type: remote::anthropic - provider_type: remote::gemini + - provider_type: remote::vertexai - provider_type: remote::groq - provider_type: remote::sambanova - provider_type: inline::sentence-transformers diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 188c66275..05e1b4576 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -65,6 +65,11 @@ providers: provider_type: remote::gemini config: api_key: ${env.GEMINI_API_KEY:=} + - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai} + provider_type: remote::vertexai + config: + project: ${env.VERTEX_AI_PROJECT:=} + location: ${env.VERTEX_AI_LOCATION:=us-central1} - provider_id: groq provider_type: remote::groq config: diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index f95a03a9e..1a4f81d49 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -14,6 +14,7 @@ distribution_spec: - provider_type: remote::openai - provider_type: remote::anthropic - provider_type: remote::gemini + - provider_type: remote::vertexai - provider_type: remote::groq - provider_type: remote::sambanova - provider_type: inline::sentence-transformers diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 8bd737686..46bd12956 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -65,6 +65,11 @@ providers: provider_type: remote::gemini config: api_key: ${env.GEMINI_API_KEY:=} + - provider_id: 
${env.VERTEX_AI_PROJECT:+vertexai} + provider_type: remote::vertexai + config: + project: ${env.VERTEX_AI_PROJECT:=} + location: ${env.VERTEX_AI_LOCATION:=us-central1} - provider_id: groq provider_type: remote::groq config: diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index a970f2d1c..0270b68ad 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -56,6 +56,7 @@ ENABLED_INFERENCE_PROVIDERS = [ "fireworks", "together", "gemini", + "vertexai", "groq", "sambanova", "anthropic", @@ -71,6 +72,7 @@ INFERENCE_PROVIDER_IDS = { "tgi": "${env.TGI_URL:+tgi}", "cerebras": "${env.CEREBRAS_API_KEY:+cerebras}", "nvidia": "${env.NVIDIA_API_KEY:+nvidia}", + "vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}", } @@ -246,6 +248,14 @@ def get_distribution_template() -> DistributionTemplate: "", "Gemini API Key", ), + "VERTEX_AI_PROJECT": ( + "", + "Google Cloud Project ID for Vertex AI", + ), + "VERTEX_AI_LOCATION": ( + "us-central1", + "Google Cloud Location for Vertex AI", + ), "SAMBANOVA_API_KEY": ( "", "SambaNova API Key", diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index a8bc96a77..1801cdcad 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]: description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="vertexai", + pip_packages=["litellm", "google-cloud-aiplatform"], + module="llama_stack.providers.remote.inference.vertexai", + config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", + description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + +• Enterprise-grade security: Uses Google Cloud's security controls and IAM +• Better integration: Seamless integration with other Google Cloud services +• Advanced features: Access to additional Vertex AI features like model tuning and monitoring +• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys + +Configuration: +- Set VERTEX_AI_PROJECT environment variable (required) +- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1) +- Use Google Cloud Application Default Credentials or service account key + +Authentication Setup: +Option 1 (Recommended): gcloud auth application-default login +Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path + +Available Models: +- vertex_ai/gemini-2.0-flash +- vertex_ai/gemini-2.5-flash +- vertex_ai/gemini-2.5-pro""", + ), + ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/remote/inference/vertexai/__init__.py b/llama_stack/providers/remote/inference/vertexai/__init__.py new file mode 100644 index 000000000..d9e9419be --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .config import VertexAIConfig + + +async def get_adapter_impl(config: VertexAIConfig, _deps): + from .vertexai import VertexAIInferenceAdapter + + impl = VertexAIInferenceAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/inference/vertexai/config.py b/llama_stack/providers/remote/inference/vertexai/config.py new file mode 100644 index 000000000..659de653e --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/config.py @@ -0,0 +1,45 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +class VertexAIProviderDataValidator(BaseModel): + vertex_project: str | None = Field( + default=None, + description="Google Cloud project ID for Vertex AI", + ) + vertex_location: str | None = Field( + default=None, + description="Google Cloud location for Vertex AI (e.g., us-central1)", + ) + + +@json_schema_type +class VertexAIConfig(BaseModel): + project: str = Field( + description="Google Cloud project ID for Vertex AI", + ) + location: str = Field( + default="us-central1", + description="Google Cloud location for Vertex AI", + ) + + @classmethod + def sample_run_config( + cls, + project: str = "${env.VERTEX_AI_PROJECT:=}", + location: str = "${env.VERTEX_AI_LOCATION:=us-central1}", + **kwargs, + ) -> dict[str, Any]: + return { + "project": project, + "location": location, + } diff --git a/llama_stack/providers/remote/inference/vertexai/models.py b/llama_stack/providers/remote/inference/vertexai/models.py new file mode 100644 index 000000000..e72db533d --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/models.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, +) + +# Vertex AI model IDs with vertex_ai/ prefix as required by litellm +LLM_MODEL_IDS = [ + "vertex_ai/gemini-2.0-flash", + "vertex_ai/gemini-2.5-flash", + "vertex_ai/gemini-2.5-pro", +] + +SAFETY_MODELS_ENTRIES = list[ProviderModelEntry]() + +MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + SAFETY_MODELS_ENTRIES diff --git a/llama_stack/providers/remote/inference/vertexai/vertexai.py b/llama_stack/providers/remote/inference/vertexai/vertexai.py new file mode 100644 index 000000000..8807fd0e6 --- /dev/null +++ b/llama_stack/providers/remote/inference/vertexai/vertexai.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.inference import ChatCompletionRequest +from llama_stack.providers.utils.inference.litellm_openai_mixin import ( + LiteLLMOpenAIMixin, +) + +from .config import VertexAIConfig +from .models import MODEL_ENTRIES + + +class VertexAIInferenceAdapter(LiteLLMOpenAIMixin): + def __init__(self, config: VertexAIConfig) -> None: + LiteLLMOpenAIMixin.__init__( + self, + MODEL_ENTRIES, + litellm_provider_name="vertex_ai", + api_key_from_config=None, # Vertex AI uses ADC, not API keys + provider_data_api_key_field="vertex_project", # Use project for validation + ) + self.config = config + + def get_api_key(self) -> str: + # Vertex AI doesn't use API keys, it uses Application Default Credentials + # Return empty string to let litellm handle authentication via ADC + return "" + + async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: + # Get base parameters from parent + params = await super()._get_params(request) + + # Add Vertex AI specific parameters + provider_data = self.get_request_provider_data() + if provider_data: + if getattr(provider_data, "vertex_project", None): + params["vertex_project"] = provider_data.vertex_project + if getattr(provider_data, "vertex_location", None): + params["vertex_location"] = provider_data.vertex_location + else: + params["vertex_project"] = self.config.project + params["vertex_location"] = self.config.location + + # Remove api_key since Vertex AI uses ADC + params.pop("api_key", None) + + return params diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 0222bfb79..72137662d 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -34,6 +34,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) "remote::runpod", "remote::sambanova", "remote::tgi", + "remote::vertexai", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") diff --git a/tests/integration/inference/test_text_inference.py b/tests/integration/inference/test_text_inference.py index 08e19726e..d7ffe5929 100644 --- a/tests/integration/inference/test_text_inference.py +++ b/tests/integration/inference/test_text_inference.py @@ -29,6 +29,7 @@ def skip_if_model_doesnt_support_completion(client_with_models, model_id): "remote::openai", "remote::anthropic", "remote::gemini", + "remote::vertexai", "remote::groq", "remote::sambanova", ) From 8faff925919d034f2f3e971a2d24f64b221ed4d9 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Mon, 11 Aug 2025 09:38:54 -0500 Subject: [PATCH 07/17] chore: remove redundant code in unregister_toolgroup (#3092) # What does this PR do? 
removes redundant code ## Test Plan ci --- llama_stack/core/routing_tables/toolgroups.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index e172af991..6910b3906 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -124,10 +124,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return toolgroup async def unregister_toolgroup(self, toolgroup_id: str) -> None: - tool_group = await self.get_tool_group(toolgroup_id) - if tool_group is None: - raise ToolGroupNotFoundError(toolgroup_id) - await self.unregister_object(tool_group) + await self.unregister_object(await self.get_tool_group(toolgroup_id)) async def shutdown(self) -> None: pass From 7448a4a88c71a996b7cfa980d9d55915c1ab5094 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Mon, 11 Aug 2025 08:39:52 -0600 Subject: [PATCH 08/17] chore: Updating UI Sidebar (#3081) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This updates the sidebar to look a little more like other popular ones. Screenshot 2025-08-08 at 11 25 31 PM
## Test Plan Signed-off-by: Francisco Javier Arceo --- llama_stack/ui/app/chat-playground/page.tsx | 2 +- .../ui/components/layout/app-sidebar.tsx | 168 ++++++++++-------- 2 files changed, 96 insertions(+), 74 deletions(-) diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index c31248b78..d8094af85 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -175,7 +175,7 @@ const handleSubmitWithContent = async (content: string) => { return (
-          Chat Playground
+          Chat Playground (Completions)