mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
[bugfix] fix case for agent when memory bank registered without specifying provider_id (#264)
* fix case where memory bank is registered without provider_id * memory test * agents unit test
This commit is contained in:
parent
9fcf5d58e0
commit
be3c5c034d
6 changed files with 151 additions and 5 deletions
|
@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool):
|
|||
response = await client.list_memory_banks()
|
||||
cprint(f"list_memory_banks response={response}", "green")
|
||||
|
||||
# register memory bank for the first time
|
||||
response = await client.register_memory_bank(
|
||||
VectorMemoryBankDef(
|
||||
identifier="test_bank2",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
)
|
||||
cprint(f"register_memory_bank response={response}", "blue")
|
||||
|
||||
# list again after registering
|
||||
response = await client.list_memory_banks()
|
||||
cprint(f"list_memory_banks response={response}", "green")
|
||||
|
||||
|
||||
def main(host: str, port: int, stream: bool = True):
|
||||
asyncio.run(run_main(host, port, stream))
|
||||
|
|
|
@ -110,10 +110,16 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||
entries = self.registry.get(obj.identifier, [])
|
||||
for entry in entries:
|
||||
if entry.provider_id == obj.provider_id:
|
||||
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
|
||||
if entry.provider_id == obj.provider_id or not obj.provider_id:
|
||||
print(
|
||||
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
|
||||
)
|
||||
return
|
||||
|
||||
# if provider_id is not specified, we'll pick an arbitrary one from existing entries
|
||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||
|
||||
|
|
|
@ -31,4 +31,4 @@ providers:
|
|||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: /Users/ashwin/.llama/runtime/kvstore.db
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
|
|
|
@ -64,6 +64,24 @@ def search_query_messages():
|
|||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def attachment_message():
|
||||
return [
|
||||
UserMessage(
|
||||
content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def query_attachment_messages():
|
||||
return [
|
||||
UserMessage(
|
||||
content="What are the top 5 topics that were explained? Only list succinct bullet points."
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn(agents_settings, sample_messages):
|
||||
agents_impl = agents_settings["impl"]
|
||||
|
@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages):
|
|||
assert len(final_event.turn.output_message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rag_agent_as_attachments(
|
||||
agents_settings, attachment_message, query_attachment_messages
|
||||
):
|
||||
urls = [
|
||||
"memory_optimizations.rst",
|
||||
"chat.rst",
|
||||
"llama3.rst",
|
||||
"datasets.rst",
|
||||
"qat_finetune.rst",
|
||||
"lora_finetune.rst",
|
||||
]
|
||||
|
||||
attachments = [
|
||||
Attachment(
|
||||
content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
|
||||
mime_type="text/plain",
|
||||
)
|
||||
for i, url in enumerate(urls)
|
||||
]
|
||||
|
||||
agents_impl = agents_settings["impl"]
|
||||
|
||||
agent_config = AgentConfig(
|
||||
model=agents_settings["common_params"]["model"],
|
||||
instructions=agents_settings["common_params"]["instructions"],
|
||||
enable_session_persistence=True,
|
||||
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
tools=[
|
||||
MemoryToolDefinition(
|
||||
memory_bank_configs=[],
|
||||
query_generator_config={
|
||||
"type": "default",
|
||||
"sep": " ",
|
||||
},
|
||||
max_tokens_in_context=4096,
|
||||
max_chunks=10,
|
||||
),
|
||||
],
|
||||
max_infer_iters=5,
|
||||
)
|
||||
|
||||
create_response = await agents_impl.create_agent(agent_config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
# Create a session
|
||||
session_create_response = await agents_impl.create_agent_session(
|
||||
agent_id, "Test Session"
|
||||
)
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
# Create and execute a turn
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=attachment_message,
|
||||
attachments=attachments,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
# Create a second turn querying the agent
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=query_attachment_messages,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
turn_response = [
|
||||
chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_turn_with_brave_search(
|
||||
agents_settings, search_query_messages
|
||||
|
|
|
@ -2,8 +2,8 @@ providers:
|
|||
- provider_id: test-faiss
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
- provider_id: test-chroma
|
||||
provider_type: remote::chroma
|
||||
- provider_id: test-chromadb
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
host: localhost
|
||||
port: 6001
|
||||
|
|
|
@ -89,6 +89,30 @@ async def test_banks_list(memory_settings):
|
|||
assert len(response) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_register(memory_settings):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
banks_impl = memory_settings["memory_banks_impl"]
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank_no_provider",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
|
||||
# register same memory bank with same id again will fail
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(memory_settings, sample_documents):
|
||||
memory_impl = memory_settings["memory_impl"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue