From e0731ba353e021534034456e973dd72170f8741c Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Mon, 16 Dec 2024 11:21:51 -0800
Subject: [PATCH] Fix agent executor

---
 llama_stack/apis/common/training_types.py   |  2 --
 .../agents/meta_reference/agent_instance.py | 20 +++++++----
 .../providers/tests/agents/fixtures.py      | 34 +++++++++++++++----
 .../providers/utils/memory/vector_store.py  | 22 ++++++++++++
 4 files changed, 62 insertions(+), 16 deletions(-)

diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index ed278553e..b4bd1b0c6 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -10,8 +10,6 @@ from typing import Optional
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
-from llama_stack.apis.common.deployment_types import URL
-
 
 @json_schema_type
 class PostTrainingMetric(BaseModel):
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index b403b9203..9e0c677dc 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -26,6 +26,7 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
@@ -387,7 +388,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         if rag_context:
             last_message = input_messages[-1]
-            last_message.context = "\n".join(rag_context)
+            last_message.context = rag_context
 
         elif attachments and AgentTool.code_interpreter.value in enabled_tools:
             urls = [a.content for a in attachments if isinstance(a.content, URL)]
@@ -687,7 +688,7 @@ class ChatAgent(ShieldRunnerMixin):
 
     async def _retrieve_context(
         self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[Optional[List[str]], Optional[List[int]]]:  # (rag_context, bank_ids)
+    ) -> Tuple[Optional[InterleavedContent], List[int]]:  # (rag_context, bank_ids)
         bank_ids = []
 
         memory = self._memory_tool_definition()
@@ -755,11 +756,16 @@ class ChatAgent(ShieldRunnerMixin):
                 break
             picked.append(f"id:{c.document_id}; content:{c.content}")
 
-        return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ], bank_ids
+        return (
+            concat_interleaved_content(
+                [
+                    "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                    *picked,
+                    "\n=== END-RETRIEVED-CONTEXT ===\n",
+                ]
+            ),
+            bank_ids,
+        )
 
     def _get_tools(self) -> List[ToolDefinition]:
         ret = []
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 93a011c95..13c250439 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -9,7 +9,7 @@ import tempfile
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
 
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.agents.meta_reference import (
@@ -67,22 +67,42 @@ async def agents_stack(request, inference_model, safety_shield):
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
+        if key == "inference":
+            providers[key].append(
+                Provider(
+                    provider_id="agents_memory_provider",
+                    provider_type="inline::sentence-transformers",
+                    config={},
+                )
+            )
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
+    models = [
+        ModelInput(
+            model_id=model,
+            model_type=ModelType.llm,
+            provider_id=providers["inference"][0].provider_id,
+        )
+        for model in inference_models
+    ]
+    models.append(
+        ModelInput(
+            model_id="all-MiniLM-L6-v2",
+            model_type=ModelType.embedding,
+            provider_id="agents_memory_provider",
+            metadata={"embedding_dimension": 384},
+        )
+    )
+
     test_stack = await construct_stack_for_test(
         [Api.agents, Api.inference, Api.safety, Api.memory],
         providers,
         provider_data,
-        models=[
-            ModelInput(
-                model_id=model,
-            )
-            for model in inference_models
-        ],
+        models=models,
         shields=[safety_shield] if safety_shield else [],
     )
     return test_stack
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 5c58c5fac..cfe5c2816 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -21,6 +21,7 @@ from pypdf import PdfReader
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
+from llama_stack.apis.inference import InterleavedContent, TextContentItem
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api
@@ -88,6 +89,26 @@ def content_from_data(data_url: str) -> str:
         return ""
 
 
+def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
+    """Concatenate interleaved content into a single flat list, converting bare strings to TextContentItem."""
+
+    ret = []
+
+    def _process(c):
+        if isinstance(c, str):
+            ret.append(TextContentItem(text=c))
+        elif isinstance(c, list):
+            for item in c:
+                _process(item)
+        else:
+            ret.append(c)
+
+    for c in content:
+        _process(c)
+
+    return ret
+
+
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
@@ -125,6 +146,7 @@ def make_overlapped_chunks(
     for i in range(0, len(tokens), window_len - overlap_len):
         toks = tokens[i : i + window_len]
         chunk = tokenizer.decode(toks)
+        # chunk is a string
         chunks.append(
             Chunk(content=chunk, token_count=len(toks), document_id=document_id)
         )
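
Reviewer note (illustrative sketch, not part of the patch): the snippet below
mirrors the flattening behavior of the new concat_interleaved_content helper so
its contract is easy to check in isolation. The TextContentItem dataclass is a
minimal stand-in for the real llama_stack.apis.inference type (assumed here to
carry at least a `text` field); only the flattening logic is taken from the
patch.

    from dataclasses import dataclass
    from typing import List, Union

    @dataclass
    class TextContentItem:
        # Stand-in for llama_stack.apis.inference.TextContentItem (assumption:
        # the real type carries at least a `text` field).
        text: str

    InterleavedContent = Union[str, TextContentItem, List[Union[str, TextContentItem]]]

    def concat_interleaved_content(content: List[InterleavedContent]) -> List[TextContentItem]:
        # Same logic as the helper added to vector_store.py in this patch.
        ret: List[TextContentItem] = []

        def _process(c):
            if isinstance(c, str):
                ret.append(TextContentItem(text=c))
            elif isinstance(c, list):
                for item in c:
                    _process(item)
            else:
                ret.append(c)

        for c in content:
            _process(c)
        return ret

    # _retrieve_context passes a header string, the picked chunks, and a
    # footer string; the result is one flat list of content items, which is
    # why last_message.context no longer needs "\n".join(...).
    result = concat_interleaved_content(
        [
            "=== START-RETRIEVED-CONTEXT ===\n",
            ["id:doc-1; content:hello", TextContentItem(text="id:doc-2; content:world")],
            "\n=== END-RETRIEVED-CONTEXT ===\n",
        ]
    )
    assert [type(x) for x in result] == [TextContentItem] * 4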
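
A second sketch, for the make_overlapped_chunks loop that the patch annotates
with "# chunk is a string": the stride is window_len - overlap_len, so
consecutive windows share overlap_len tokens. Plain integers stand in for
tokenizer output here.

    tokens = list(range(10))  # pretend token ids from tokenizer.encode(...)
    window_len, overlap_len = 4, 2
    windows = [
        tokens[i : i + window_len]
        for i in range(0, len(tokens), window_len - overlap_len)
    ]
    assert windows == [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9], [8, 9]]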