From 28a0fe57ccb7078fe35708ff67815164de8fa14e Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 6 Feb 2025 16:12:29 -0800 Subject: [PATCH] fix: Update rag examples to use fresh faiss index every time (#998) # What does this PR do? In several examples we use the same faiss index , which means running it multiple times fills up the index with duplicates which eventually degrades the model performance on RAG as multiple copies of the same irrelevant chunks might be picked up several times. Fix is to ensure we create a new index each time. Resolves issue in this discussion - https://github.com/meta-llama/llama-stack/discussions/995 ## Test Plan Re-ran the getting started guide multiple times to see the same output Co-authored-by: Hardik Shah --- docs/getting_started.ipynb | 9 +++++---- docs/source/getting_started/index.md | 3 ++- tests/client-sdk/agents/test_agents.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 3087a2b3b..4e4893158 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -89,7 +89,7 @@ "# install a branch of llama stack\n", "import os\n", "os.environ[\"UV_SYSTEM_PYTHON\"] = \"1\"\n", - "!pip install uv \n", + "!pip install uv\n", "!uv pip install llama-stack" ] }, @@ -691,7 +691,7 @@ " from google.colab import userdata\n", " os.environ['TOGETHER_API_KEY'] = userdata.get('TOGETHER_API_KEY')\n", " os.environ['TAVILY_SEARCH_API_KEY'] = userdata.get('TAVILY_SEARCH_API_KEY')\n", - "except ImportError: \n", + "except ImportError:\n", " print(\"Not in Google Colab environment\")\n", "\n", "for key in ['TOGETHER_API_KEY', 'TAVILY_SEARCH_API_KEY']:\n", @@ -1656,6 +1656,7 @@ } ], "source": [ + "import uuid\n", "from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n", @@ -1673,7 +1674,7 @@ " for i, url in enumerate(urls)\n", "]\n", "\n", - "vector_db_id = \"test-vector-db\"\n", + "vector_db_id = f\"test-vector-db-{uuid.uuid4().hex}\"\n", "client.vector_dbs.register(\n", " vector_db_id=vector_db_id,\n", " embedding_model=\"all-MiniLM-L6-v2\",\n", @@ -3098,7 +3099,7 @@ } ], "source": [ - "# NBVAL_SKIP \n", + "# NBVAL_SKIP\n", "print(f\"Getting traces for session_id={session_id}\")\n", "import json\n", "\n", diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index d8bf42533..b28b9afa3 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -173,6 +173,7 @@ Here is an example of a simple RAG (Retrieval Augmented Generation) chatbot agen ```python import os +import uuid from termcolor import cprint from llama_stack_client.lib.agents.agent import Agent @@ -214,7 +215,7 @@ documents = [ ] # Register a vector database -vector_db_id = "test-vector-db" +vector_db_id = f"test-vector-db-{uuid.uuid4().hex}" client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model="all-MiniLM-L6-v2", diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 7b7bce813..2b1db7df0 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -297,7 +297,7 @@ def test_override_system_message_behavior(llama_stack_client, agent_config): You are an expert in composing functions. You are given a question and a set of possible functions. Based on the question, you may or may not need to make one or more function/tool calls to achieve the purpose. If none of the function can be used, don't return [], instead answer the question directly without using functions. If the given question lacks the parameters required by the function, - also point it out. + also point it out. {{ function_description }} """ @@ -414,7 +414,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): ) for i, url in enumerate(urls) ] - vector_db_id = "test-vector-db" + vector_db_id = f"test-vector-db-{uuid4()}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model="all-MiniLM-L6-v2",