mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
[bugfix] fix case for agent when memory bank registered without specifying provider_id (#264)
* fix case where memory bank is registered without provider_id * memory test * agents unit test
This commit is contained in:
parent
9fcf5d58e0
commit
be3c5c034d
6 changed files with 151 additions and 5 deletions
|
@ -92,6 +92,21 @@ async def run_main(host: str, port: int, stream: bool):
|
||||||
response = await client.list_memory_banks()
|
response = await client.list_memory_banks()
|
||||||
cprint(f"list_memory_banks response={response}", "green")
|
cprint(f"list_memory_banks response={response}", "green")
|
||||||
|
|
||||||
|
# register memory bank for the first time
|
||||||
|
response = await client.register_memory_bank(
|
||||||
|
VectorMemoryBankDef(
|
||||||
|
identifier="test_bank2",
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
overlap_size_in_tokens=64,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cprint(f"register_memory_bank response={response}", "blue")
|
||||||
|
|
||||||
|
# list again after registering
|
||||||
|
response = await client.list_memory_banks()
|
||||||
|
cprint(f"list_memory_banks response={response}", "green")
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int, stream: bool = True):
|
def main(host: str, port: int, stream: bool = True):
|
||||||
asyncio.run(run_main(host, port, stream))
|
asyncio.run(run_main(host, port, stream))
|
||||||
|
|
|
@ -110,10 +110,16 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||||
entries = self.registry.get(obj.identifier, [])
|
entries = self.registry.get(obj.identifier, [])
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if entry.provider_id == obj.provider_id:
|
if entry.provider_id == obj.provider_id or not obj.provider_id:
|
||||||
print(f"`{obj.identifier}` already registered with `{obj.provider_id}`")
|
print(
|
||||||
|
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# if provider_id is not specified, we'll pick an arbitrary one from existing entries
|
||||||
|
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||||
|
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
if obj.provider_id not in self.impls_by_provider_id:
|
if obj.provider_id not in self.impls_by_provider_id:
|
||||||
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
raise ValueError(f"Provider `{obj.provider_id}` not found")
|
||||||
|
|
||||||
|
|
|
@ -31,4 +31,4 @@ providers:
|
||||||
persistence_store:
|
persistence_store:
|
||||||
namespace: null
|
namespace: null
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: /Users/ashwin/.llama/runtime/kvstore.db
|
db_path: ~/.llama/runtime/kvstore.db
|
||||||
|
|
|
@ -64,6 +64,24 @@ def search_query_messages():
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def attachment_message():
    """Return a one-element message list announcing the Torchtune attachments."""
    intro = UserMessage(
        content="I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
    )
    return [intro]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def query_attachment_messages():
    """Return a one-element message list holding the follow-up question."""
    question = UserMessage(
        content="What are the top 5 topics that were explained? Only list succinct bullet points."
    )
    return [question]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_agent_turn(agents_settings, sample_messages):
|
async def test_create_agent_turn(agents_settings, sample_messages):
|
||||||
agents_impl = agents_settings["impl"]
|
agents_impl = agents_settings["impl"]
|
||||||
|
@ -123,6 +141,89 @@ async def test_create_agent_turn(agents_settings, sample_messages):
|
||||||
assert len(final_event.turn.output_message.content) > 0
|
assert len(final_event.turn.output_message.content) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rag_agent_as_attachments(
    agents_settings, attachment_message, query_attachment_messages
):
    """Run an attachment turn followed by a query turn on the same session.

    The agent is created with a memory tool that has no memory banks
    configured (``memory_bank_configs=[]``). The first turn sends the
    Torchtune tutorial attachments; the second turn sends only a question.
    Both turns are streamed, and the test only asserts that each stream
    yields at least one chunk.
    """
    # Torchtune tutorial pages, referenced by URL below.
    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]

    attachments = [
        Attachment(
            content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
            mime_type="text/plain",
        )
        for url in urls
    ]

    agents_impl = agents_settings["impl"]

    agent_config = AgentConfig(
        model=agents_settings["common_params"]["model"],
        instructions=agents_settings["common_params"]["instructions"],
        enable_session_persistence=True,
        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
        input_shields=[],
        output_shields=[],
        tools=[
            MemoryToolDefinition(
                # No banks are configured up front for this tool.
                memory_bank_configs=[],
                query_generator_config={
                    "type": "default",
                    "sep": " ",
                },
                max_tokens_in_context=4096,
                max_chunks=10,
            ),
        ],
        max_infer_iters=5,
    )

    create_response = await agents_impl.create_agent(agent_config)
    agent_id = create_response.agent_id

    # Create a session
    session_create_response = await agents_impl.create_agent_session(
        agent_id, "Test Session"
    )
    session_id = session_create_response.session_id

    # Create and execute a turn that carries the attachments
    turn_request = dict(
        agent_id=agent_id,
        session_id=session_id,
        messages=attachment_message,
        attachments=attachments,
        stream=True,
    )

    turn_response = [
        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
    ]

    assert len(turn_response) > 0

    # Create a second turn querying the agent (same session, no attachments)
    turn_request = dict(
        agent_id=agent_id,
        session_id=session_id,
        messages=query_attachment_messages,
        stream=True,
    )

    turn_response = [
        chunk async for chunk in agents_impl.create_agent_turn(**turn_request)
    ]

    assert len(turn_response) > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_create_agent_turn_with_brave_search(
|
async def test_create_agent_turn_with_brave_search(
|
||||||
agents_settings, search_query_messages
|
agents_settings, search_query_messages
|
||||||
|
|
|
@ -2,8 +2,8 @@ providers:
|
||||||
- provider_id: test-faiss
|
- provider_id: test-faiss
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
config: {}
|
config: {}
|
||||||
- provider_id: test-chroma
|
- provider_id: test-chromadb
|
||||||
provider_type: remote::chroma
|
provider_type: remote::chromadb
|
||||||
config:
|
config:
|
||||||
host: localhost
|
host: localhost
|
||||||
port: 6001
|
port: 6001
|
||||||
|
|
|
@ -89,6 +89,30 @@ async def test_banks_list(memory_settings):
|
||||||
assert len(response) == 0
|
assert len(response) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_banks_register(memory_settings):
    """Register a bank that has no provider_id twice; expect one listed bank.

    NOTE: this needs you to ensure that you are starting from a clean state,
    but so far we don't have an unregister API unfortunately, so be careful.
    """
    banks_impl = memory_settings["memory_banks_impl"]
    bank = VectorMemoryBankDef(
        identifier="test_bank_no_provider",
        embedding_model="all-MiniLM-L6-v2",
        chunk_size_in_tokens=512,
        overlap_size_in_tokens=64,
    )

    # Pass 1 registers the bank for the first time. Pass 2 registers the very
    # same definition again; the listing must still contain exactly one bank.
    for _attempt in range(2):
        await banks_impl.register_memory_bank(bank)
        listed = await banks_impl.list_memory_banks()
        assert isinstance(listed, list)
        assert len(listed) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_documents(memory_settings, sample_documents):
|
async def test_query_documents(memory_settings, sample_documents):
|
||||||
memory_impl = memory_settings["memory_impl"]
|
memory_impl = memory_settings["memory_impl"]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue