Mirror of https://github.com/meta-llama/llama-stack.git

Fix agent executor

commit e0731ba353 (parent 59ce047aea)
4 changed files with 62 additions and 16 deletions
@@ -10,8 +10,6 @@ from typing import Optional
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel

-from llama_stack.apis.common.deployment_types import URL
-

 @json_schema_type
 class PostTrainingMetric(BaseModel):
@@ -26,6 +26,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.apis.safety import * # noqa: F403

 from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
 from llama_stack.providers.utils.telemetry import tracing

 from .persistence import AgentPersistence
@@ -387,7 +388,7 @@ class ChatAgent(ShieldRunnerMixin):

         if rag_context:
             last_message = input_messages[-1]
-            last_message.context = "\n".join(rag_context)
+            last_message.context = rag_context

         elif attachments and AgentTool.code_interpreter.value in enabled_tools:
             urls = [a.content for a in attachments if isinstance(a.content, URL)]
@@ -687,7 +688,7 @@ class ChatAgent(ShieldRunnerMixin):

     async def _retrieve_context(
         self, session_id: str, messages: List[Message], attachments: List[Attachment]
-    ) -> Tuple[Optional[List[str]], Optional[List[int]]]:  # (rag_context, bank_ids)
+    ) -> Tuple[Optional[InterleavedContent], List[int]]:  # (rag_context, bank_ids)
         bank_ids = []

         memory = self._memory_tool_definition()
@@ -755,11 +756,16 @@ class ChatAgent(ShieldRunnerMixin):
                 break
             picked.append(f"id:{c.document_id}; content:{c.content}")

-        return [
-            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
-            *picked,
-            "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ], bank_ids
+        return (
+            concat_interleaved_content(
+                [
+                    "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
+                    *picked,
+                    "\n=== END-RETRIEVED-CONTEXT ===\n",
+                ]
+            ),
+            bank_ids,
+        )

     def _get_tools(self) -> List[ToolDefinition]:
         ret = []
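To make the shape change in this hunk explicit: the first element of the returned tuple used to be a plain list of strings that the caller joined with newlines; it is now the result of concat_interleaved_content, which the caller attaches to last_message.context directly (see the hunk at line 388 above). A minimal illustration with made-up values:

picked = ["id:doc-1; content:alpha", "id:doc-2; content:beta"]  # hypothetical retrieved chunks
bank_ids = ["bank-0"]                                           # hypothetical memory bank ids

# old: return [header, *picked, footer], bank_ids
#      -> first element is a list of plain strings
# new: return (concat_interleaved_content([header, *picked, footer]), bank_ids)
#      -> first element is InterleavedContent, ready to assign to a message's context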
@@ -9,7 +9,7 @@ import tempfile
 import pytest
 import pytest_asyncio

-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider

 from llama_stack.providers.inline.agents.meta_reference import (
@@ -67,22 +67,42 @@ async def agents_stack(request, inference_model, safety_shield):
     for key in ["inference", "safety", "memory", "agents"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
+        if key == "inference":
+            providers[key].append(
+                Provider(
+                    provider_id="agents_memory_provider",
+                    provider_type="inline::sentence-transformers",
+                    config={},
+                )
+            )
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)

     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
+    models = [
+        ModelInput(
+            model_id=model,
+            model_type=ModelType.llm,
+            provider_id=providers["inference"][0].provider_id,
+        )
+        for model in inference_models
+    ]
+    models.append(
+        ModelInput(
+            model_id="all-MiniLM-L6-v2",
+            model_type=ModelType.embedding,
+            provider_id="agents_memory_provider",
+            metadata={"embedding_dimension": 384},
+        )
+    )

     test_stack = await construct_stack_for_test(
         [Api.agents, Api.inference, Api.safety, Api.memory],
         providers,
         provider_data,
-        models=[
-            ModelInput(
-                model_id=model,
-            )
-            for model in inference_models
-        ],
+        models=models,
         shields=[safety_shield] if safety_shield else [],
     )
     return test_stack
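A self-contained sketch of the model list this fixture now registers, using a stand-in dataclass in place of llama_stack.apis.models.ModelInput (field names follow the diff; the example model id and inference provider id are made up):

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional


class ModelType(str, Enum):
    llm = "llm"
    embedding = "embedding"


@dataclass
class ModelInput:  # stand-in; the real class lives in llama_stack.apis.models
    model_id: str
    model_type: ModelType
    provider_id: Optional[str] = None
    metadata: Dict = field(default_factory=dict)


inference_models = ["Llama3.1-8B-Instruct"]  # hypothetical test model
inference_provider_id = "inference0"         # hypothetical inference provider id

# One LLM entry per inference model, plus the embedding model served by the
# sentence-transformers provider registered above as "agents_memory_provider".
models = [
    ModelInput(model_id=m, model_type=ModelType.llm, provider_id=inference_provider_id)
    for m in inference_models
]
models.append(
    ModelInput(
        model_id="all-MiniLM-L6-v2",
        model_type=ModelType.embedding,
        provider_id="agents_memory_provider",
        metadata={"embedding_dimension": 384},
    )
)
assert [m.model_type for m in models] == [ModelType.llm, ModelType.embedding]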
@@ -21,6 +21,7 @@ from pypdf import PdfReader
 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

+from llama_stack.apis.inference import InterleavedContent, TextContentItem
 from llama_stack.apis.memory import * # noqa: F403
 from llama_stack.apis.memory_banks import VectorMemoryBank
 from llama_stack.providers.datatypes import Api
@@ -88,6 +89,26 @@ def content_from_data(data_url: str) -> str:
         return ""


+def concat_interleaved_content(content: List[InterleavedContent]) -> InterleavedContent:
+    """concatenate interleaved content into a single list. ensure that 'str's are converted to TextContentItem when in a list"""
+
+    ret = []
+
+    def _process(c):
+        if isinstance(c, str):
+            ret.append(TextContentItem(text=c))
+        elif isinstance(c, list):
+            for item in c:
+                _process(item)
+        else:
+            ret.append(c)
+
+    for c in content:
+        _process(c)
+
+    return ret
+
+
 async def content_from_doc(doc: MemoryBankDocument) -> str:
     if isinstance(doc.content, URL):
         if doc.content.uri.startswith("data:"):
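To show what the new helper does with the pieces built in _retrieve_context, here is a runnable sketch with a stand-in for TextContentItem (the real class comes from llama_stack.apis.inference and may carry more fields):

from dataclasses import dataclass
from typing import Any, List


@dataclass
class TextContentItem:  # stand-in for llama_stack.apis.inference.TextContentItem
    text: str


def concat_interleaved_content(content: List[Any]) -> List[Any]:
    # Same flattening logic as the function added in this hunk.
    ret: List[Any] = []

    def _process(c: Any) -> None:
        if isinstance(c, str):
            ret.append(TextContentItem(text=c))
        elif isinstance(c, list):
            for item in c:
                _process(item)
        else:
            ret.append(c)

    for c in content:
        _process(c)
    return ret


# Bare strings, and strings inside nested lists, come back in order as one flat
# list of TextContentItem.
pieces = [
    "=== START-RETRIEVED-CONTEXT ===\n",
    ["id:doc-1; content:alpha", "id:doc-2; content:beta"],  # hypothetical picked chunks
    "\n=== END-RETRIEVED-CONTEXT ===\n",
]
flat = concat_interleaved_content(pieces)
assert len(flat) == 4 and all(isinstance(p, TextContentItem) for p in flat)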
@@ -125,6 +146,7 @@ def make_overlapped_chunks(
     for i in range(0, len(tokens), window_len - overlap_len):
         toks = tokens[i : i + window_len]
         chunk = tokenizer.decode(toks)
+        # chunk is a string
         chunks.append(
             Chunk(content=chunk, token_count=len(toks), document_id=document_id)
         )
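For context on the loop above, a small sketch of the overlap arithmetic (the window and overlap sizes here are illustrative, not values from the commit):

# Windows start every (window_len - overlap_len) tokens, so consecutive chunks
# share their last overlap_len tokens.
tokens = list(range(14))  # stand-in for tokenizer.encode(...) output
window_len, overlap_len = 6, 2

windows = [
    tokens[i : i + window_len]
    for i in range(0, len(tokens), window_len - overlap_len)
]
# windows == [[0..5], [4..9], [8..13], [12, 13]]
assert windows[0][-overlap_len:] == windows[1][:overlap_len]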