Fix agent executor

Ashwin Bharambe 2024-12-16 11:21:51 -08:00
parent 59ce047aea
commit e0731ba353
4 changed files with 62 additions and 16 deletions

View file

@@ -10,8 +10,6 @@ from typing import Optional
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
-from llama_stack.apis.common.deployment_types import URL
-
 
 @json_schema_type
 class PostTrainingMetric(BaseModel):

View file

@@ -26,6 +26,7 @@ from llama_stack.apis.memory_banks import *  # noqa: F403
 from llama_stack.apis.safety import *  # noqa: F403
 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing
 
 from .persistence import AgentPersistence
 
@@ -387,7 +388,7 @@ class ChatAgent(ShieldRunnerMixin):
 
         if rag_context:
             last_message = input_messages[-1]
-            last_message.context = "\n".join(rag_context)
+            last_message.context = rag_context
 
         elif attachments and AgentTool.code_interpreter.value in enabled_tools:
             urls = [a.content for a in attachments if isinstance(a.content, URL)]
@@ -687,7 +688,7 @@ class ChatAgent(ShieldRunnerMixin):
 
     async def _retrieve_context(
         self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[Optional[List[str]], Optional[List[int]]]:  # (rag_context, bank_ids)
+    ) -> Tuple[Optional[InterleavedContent], List[int]]:  # (rag_context, bank_ids)
         bank_ids = []
 
         memory = self._memory_tool_definition()
@@ -755,11 +756,16 @@ class ChatAgent(ShieldRunnerMixin):
                     break
                 picked.append(f"id:{c.document_id}; content:{c.content}")
 
-        return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ], bank_ids
+        return (
+            concat_interleaved_content(
+                [
+                    "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                    *picked,
+                    "\n=== END-RETRIEVED-CONTEXT ===\n",
+                ]
+            ),
+            bank_ids,
+        )
 
     def _get_tools(self) -> List[ToolDefinition]:
         ret = []
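
Note: the practical effect of this hunk is that last_message.context no longer receives a single "\n".join(...)-ed string but the InterleavedContent produced by concat_interleaved_content (added in vector_store.py below). A rough sketch of the resulting value, with made-up document contents:

# Not the repo's code: a sketch of what last_message.context now holds,
# assuming TextContentItem from llama_stack.apis.inference and two
# made-up retrieved documents.
from llama_stack.apis.inference import TextContentItem

rag_context = [
    TextContentItem(
        text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n"
    ),
    TextContentItem(text="id:doc-1; content:alpha"),
    TextContentItem(text="id:doc-2; content:beta"),
    TextContentItem(text="\n=== END-RETRIEVED-CONTEXT ===\n"),
]
# the caller then does: last_message.context = rag_context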

View file

@@ -9,7 +9,7 @@ import tempfile
 
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.agents.meta_reference import (
@@ -67,22 +67,42 @@ async def agents_stack(request, inference_model, safety_shield):
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
+        if key == "inference":
+            providers[key].append(
+                Provider(
+                    provider_id="agents_memory_provider",
+                    provider_type="inline::sentence-transformers",
+                    config={},
+                )
+            )
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
+    models = [
+        ModelInput(
+            model_id=model,
+            model_type=ModelType.llm,
+            provider_id=providers["inference"][0].provider_id,
+        )
+        for model in inference_models
+    ]
+    models.append(
+        ModelInput(
+            model_id="all-MiniLM-L6-v2",
+            model_type=ModelType.embedding,
+            provider_id="agents_memory_provider",
+            metadata={"embedding_dimension": 384},
+        )
+    )
+
     test_stack = await construct_stack_for_test(
         [Api.agents, Api.inference, Api.safety, Api.memory],
         providers,
         provider_data,
-        models=[
-            ModelInput(
-                model_id=model,
-            )
-            for model in inference_models
-        ],
+        models=models,
         shields=[safety_shield] if safety_shield else [],
     )
     return test_stack
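
On the metadata: all-MiniLM-L6-v2 produces 384-dimensional embeddings, which is where embedding_dimension comes from. A quick standalone check, assuming the sentence-transformers package is installed:

# Verifies the embedding dimension outside the test stack.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
print(model.get_sentence_embedding_dimension())  # prints 384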

View file

@@ -21,6 +21,7 @@ from pypdf import PdfReader
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
+from llama_stack.apis.inference import InterleavedContent, TextContentItem
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api
@@ -88,6 +89,26 @@ def content_from_data(data_url: str) -> str:
         return ""
 
 
+def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
+    """concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
+
+    ret = []
+
+    def _process(c):
+        if isinstance(c, str):
+            ret.append(TextContentItem(text=c))
+        elif isinstance(c, list):
+            for item in c:
+                _process(item)
+        else:
+            ret.append(c)
+
+    for c in content:
+        _process(c)
+
+    return ret
+
+
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
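
Illustrative use of the new helper (inputs are made up, not from the repo): strings are promoted to TextContentItem and nested lists are flattened recursively.

from llama_stack.apis.inference import TextContentItem
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content

mixed = [
    "plain string",  # promoted to TextContentItem
    [TextContentItem(text="already an item"), "nested string"],  # flattened
]
flat = concat_interleaved_content(mixed)
# -> [TextContentItem(text="plain string"),
#     TextContentItem(text="already an item"),
#     TextContentItem(text="nested string")]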
@@ -125,6 +146,7 @@ def make_overlapped_chunks(
     for i in range(0, len(tokens), window_len - overlap_len):
         toks = tokens[i : i + window_len]
         chunk = tokenizer.decode(toks)
+        # chunk is a string
         chunks.append(
             Chunk(content=chunk, token_count=len(toks), document_id=document_id)
         )
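
For reference, the stride arithmetic in that loop, with toy numbers (not from the repo): consecutive chunks share overlap_len tokens because the window advances by window_len - overlap_len.

tokens = list(range(10))  # stand-in for tokenizer output
window_len, overlap_len = 4, 2
starts = range(0, len(tokens), window_len - overlap_len)
print([tokens[i : i + window_len] for i in starts])
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9], [8, 9]]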