Fix agent executor

This commit is contained in:
Ashwin Bharambe 2024-12-16 11:21:51 -08:00
parent 59ce047aea
commit e0731ba353
4 changed files with 62 additions and 16 deletions

View file

@ -10,8 +10,6 @@ from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.apis.common.deployment_types import URL
@json_schema_type
class PostTrainingMetric(BaseModel):

View file

@ -26,6 +26,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
@ -387,7 +388,7 @@ class ChatAgent(ShieldRunnerMixin):
if rag_context:
last_message = input_messages[-1]
last_message.context = "\n".join(rag_context)
last_message.context = rag_context
elif attachments and AgentTool.code_interpreter.value in enabled_tools:
urls = [a.content for a in attachments if isinstance(a.content, URL)]
@ -687,7 +688,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _retrieve_context(
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[Optional[List[str]], Optional[List[int]]]: # (rag_context, bank_ids)
) -> Tuple[Optional[InterleavedContent], List[int]]: # (rag_context, bank_ids)
bank_ids = []
memory = self._memory_tool_definition()
@ -755,11 +756,16 @@ class ChatAgent(ShieldRunnerMixin):
break
picked.append(f"id:{c.document_id}; content:{c.content}")
return [
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
], bank_ids
return (
concat_interleaved_content(
[
"Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
*picked,
"\n=== END-RETRIEVED-CONTEXT ===\n",
]
),
bank_ids,
)
def _get_tools(self) -> List[ToolDefinition]:
ret = []

View file

@ -9,7 +9,7 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
@ -67,22 +67,42 @@ async def agents_stack(request, inference_model, safety_shield):
for key in ["inference", "safety", "memory", "agents"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
providers[key].append(
Provider(
provider_id="agents_memory_provider",
provider_type="inline::sentence-transformers",
config={},
)
)
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
models = [
ModelInput(
model_id=model,
model_type=ModelType.llm,
provider_id=providers["inference"][0].provider_id,
)
for model in inference_models
]
models.append(
ModelInput(
model_id="all-MiniLM-L6-v2",
model_type=ModelType.embedding,
provider_id="agents_memory_provider",
metadata={"embedding_dimension": 384},
)
)
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory],
providers,
provider_data,
models=[
ModelInput(
model_id=model,
)
for model in inference_models
],
models=models,
shields=[safety_shield] if safety_shield else [],
)
return test_stack

View file

@ -21,6 +21,7 @@ from pypdf import PdfReader
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import InterleavedContent, TextContentItem
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import VectorMemoryBank
from llama_stack.providers.datatypes import Api
@ -88,6 +89,26 @@ def content_from_data(data_url: str) -> str:
return ""
def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
"""concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
ret = []
def _process(c):
if isinstance(c, str):
ret.append(TextContentItem(text=c))
elif isinstance(c, list):
for item in c:
_process(item)
else:
ret.append(c)
for c in content:
_process(c)
return ret
async def content_from_doc(doc: MemoryBankDocument) -> str:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
@ -125,6 +146,7 @@ def make_overlapped_chunks(
for i in range(0, len(tokens), window_len - overlap_len):
toks = tokens[i : i + window_len]
chunk = tokenizer.decode(toks)
# chunk is a string
chunks.append(
Chunk(content=chunk, token_count=len(toks), document_id=document_id)
)