small bug fixes for inline attachments

This commit is contained in:
Ashwin Bharambe 2024-08-24 23:51:27 -07:00
parent 58e2feceb0
commit 440d125ea0
2 changed files with 32 additions and 11 deletions

View file

@ -94,7 +94,7 @@ class AgenticSystemClient(AgenticSystem):
print(f"Error with parsing or validation: {e}") print(f"Error with parsing or validation: {e}")
async def _run_agent(api, tool_definitions, user_prompts): async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
agent_config = AgentConfig( agent_config = AgentConfig(
model="Meta-Llama3.1-8B-Instruct", model="Meta-Llama3.1-8B-Instruct",
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
@ -119,6 +119,7 @@ async def _run_agent(api, tool_definitions, user_prompts):
messages=[ messages=[
UserMessage(content=content), UserMessage(content=content),
], ],
attachments=attachments,
stream=True, stream=True,
) )
) )
@ -168,17 +169,36 @@ async def run_main(host: str, port: int):
async def run_rag(host: str, port: int): async def run_rag(host: str, port: int):
api = AgenticSystemClient(f"http://{host}:{port}") api = AgenticSystemClient(f"http://{host}:{port}")
# NOTE: for this, I ran `llama_toolchain.memory.client` first which urls = [
# populated the memory banks with torchtune docs. Then grabbed the bank_id "memory_optimizations.rst",
"chat.rst",
"llama3.rst",
"datasets.rst",
"qat_finetune.rst",
"lora_finetune.rst",
]
attachments = [
Attachment(
content=URL(
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
),
mime_type="text/plain",
)
for i, url in enumerate(urls)
]
# Alternatively, you can pre-populate the memory bank with documents for example,
# using `llama_toolchain.memory.client`. Then you can grab the bank_id
# from the output of that run. # from the output of that run.
tool_definitions = [ tool_definitions = [
MemoryToolDefinition( MemoryToolDefinition(
max_tokens_in_context=2048, max_tokens_in_context=2048,
memory_bank_configs=[ memory_bank_configs=[],
AgenticSystemVectorMemoryBankConfig( # memory_bank_configs=[
bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed", # AgenticSystemVectorMemoryBankConfig(
) # bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
], # )
# ],
), ),
] ]
@ -187,11 +207,11 @@ async def run_rag(host: str, port: int):
"Tell me briefly about llama3 and torchtune", "Tell me briefly about llama3 and torchtune",
] ]
await _run_agent(api, tool_definitions, user_prompts) await _run_agent(api, tool_definitions, user_prompts, attachments)
def main(host: str, port: int): def main(host: str, port: int):
asyncio.run(run_main(host, port)) asyncio.run(run_rag(host, port))
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -580,6 +580,7 @@ class ChatAgent(ShieldRunnerMixin):
name=f"memory_bank_{session.session_id}", name=f"memory_bank_{session.session_id}",
config=VectorMemoryBankConfig( config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2", embedding_model="sentence-transformer/all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
), ),
) )
@ -619,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin):
documents = [ documents = [
MemoryBankDocument( MemoryBankDocument(
doc_id=str(uuid.uuid4()), document_id=str(uuid.uuid4()),
content=a.content, content=a.content,
mime_type=a.mime_type, mime_type=a.mime_type,
metadata={}, metadata={},