mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
small bug fixes for inline attachments
This commit is contained in:
parent
58e2feceb0
commit
440d125ea0
2 changed files with 32 additions and 11 deletions
|
@ -94,7 +94,7 @@ class AgenticSystemClient(AgenticSystem):
|
||||||
print(f"Error with parsing or validation: {e}")
|
print(f"Error with parsing or validation: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def _run_agent(api, tool_definitions, user_prompts):
|
async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
model="Meta-Llama3.1-8B-Instruct",
|
model="Meta-Llama3.1-8B-Instruct",
|
||||||
instructions="You are a helpful assistant",
|
instructions="You are a helpful assistant",
|
||||||
|
@ -119,6 +119,7 @@ async def _run_agent(api, tool_definitions, user_prompts):
|
||||||
messages=[
|
messages=[
|
||||||
UserMessage(content=content),
|
UserMessage(content=content),
|
||||||
],
|
],
|
||||||
|
attachments=attachments,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -168,17 +169,36 @@ async def run_main(host: str, port: int):
|
||||||
async def run_rag(host: str, port: int):
|
async def run_rag(host: str, port: int):
|
||||||
api = AgenticSystemClient(f"http://{host}:{port}")
|
api = AgenticSystemClient(f"http://{host}:{port}")
|
||||||
|
|
||||||
# NOTE: for this, I ran `llama_toolchain.memory.client` first which
|
urls = [
|
||||||
# populated the memory banks with torchtune docs. Then grabbed the bank_id
|
"memory_optimizations.rst",
|
||||||
|
"chat.rst",
|
||||||
|
"llama3.rst",
|
||||||
|
"datasets.rst",
|
||||||
|
"qat_finetune.rst",
|
||||||
|
"lora_finetune.rst",
|
||||||
|
]
|
||||||
|
attachments = [
|
||||||
|
Attachment(
|
||||||
|
content=URL(
|
||||||
|
uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
|
||||||
|
),
|
||||||
|
mime_type="text/plain",
|
||||||
|
)
|
||||||
|
for i, url in enumerate(urls)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Alternatively, you can pre-populate the memory bank with documents for example,
|
||||||
|
# using `llama_toolchain.memory.client`. Then you can grab the bank_id
|
||||||
# from the output of that run.
|
# from the output of that run.
|
||||||
tool_definitions = [
|
tool_definitions = [
|
||||||
MemoryToolDefinition(
|
MemoryToolDefinition(
|
||||||
max_tokens_in_context=2048,
|
max_tokens_in_context=2048,
|
||||||
memory_bank_configs=[
|
memory_bank_configs=[],
|
||||||
AgenticSystemVectorMemoryBankConfig(
|
# memory_bank_configs=[
|
||||||
bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
|
# AgenticSystemVectorMemoryBankConfig(
|
||||||
)
|
# bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
|
||||||
],
|
# )
|
||||||
|
# ],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -187,11 +207,11 @@ async def run_rag(host: str, port: int):
|
||||||
"Tell me briefly about llama3 and torchtune",
|
"Tell me briefly about llama3 and torchtune",
|
||||||
]
|
]
|
||||||
|
|
||||||
await _run_agent(api, tool_definitions, user_prompts)
|
await _run_agent(api, tool_definitions, user_prompts, attachments)
|
||||||
|
|
||||||
|
|
||||||
def main(host: str, port: int):
|
def main(host: str, port: int):
|
||||||
asyncio.run(run_main(host, port))
|
asyncio.run(run_rag(host, port))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -580,6 +580,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
name=f"memory_bank_{session.session_id}",
|
name=f"memory_bank_{session.session_id}",
|
||||||
config=VectorMemoryBankConfig(
|
config=VectorMemoryBankConfig(
|
||||||
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
|
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -619,7 +620,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
documents = [
|
documents = [
|
||||||
MemoryBankDocument(
|
MemoryBankDocument(
|
||||||
doc_id=str(uuid.uuid4()),
|
document_id=str(uuid.uuid4()),
|
||||||
content=a.content,
|
content=a.content,
|
||||||
mime_type=a.mime_type,
|
mime_type=a.mime_type,
|
||||||
metadata={},
|
metadata={},
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue