Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 15:23:51 +00:00
basic RAG seems to work
commit 58e2feceb0 (parent 830252257b)
3 changed files with 96 additions and 44 deletions
File 1 of 3: the agentic system client test script (class AgenticSystemClient).

@@ -82,6 +82,10 @@ class AgenticSystemClient(AgenticSystem):
                 if line.startswith("data:"):
                     data = line[len("data: ") :]
                     try:
+                        if "error" in data:
+                            cprint(data, "red")
+                            continue
+
                         yield AgenticSystemTurnResponseStreamChunk(
                             **json.loads(data)
                         )
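The new guard skips server-sent-event payloads that report an error instead of feeding them to the response parser. A minimal sketch of the loop this hunk modifies, assuming an httpx-style response object with `aiter_lines()` (the surrounding client internals are not shown in this view):

import json

async def iter_chunks(response):
    async for line in response.aiter_lines():
        if not line.startswith("data:"):
            continue
        data = line[len("data: ") :]
        # Crude substring check, exactly as in the diff: any payload containing
        # "error" is printed and skipped rather than parsed as a stream chunk.
        if "error" in data:
            print(f"server error: {data}")
            continue
        yield json.loads(data)  # would feed AgenticSystemTurnResponseStreamChunk(**...)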
@@ -90,8 +94,41 @@ class AgenticSystemClient(AgenticSystem):
                         print(f"Error with parsing or validation: {e}")
 
 
+async def _run_agent(api, tool_definitions, user_prompts):
+    agent_config = AgentConfig(
+        model="Meta-Llama3.1-8B-Instruct",
+        instructions="You are a helpful assistant",
+        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
+        tools=tool_definitions,
+        tool_choice=ToolChoice.auto,
+        tool_prompt_format=ToolPromptFormat.function_tag,
+    )
+
+    create_response = await api.create_agentic_system(agent_config)
+    session_response = await api.create_agentic_system_session(
+        agent_id=create_response.agent_id,
+        session_name="test_session",
+    )
+
+    for content in user_prompts:
+        cprint(f"User> {content}", color="blue")
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                agent_id=create_response.agent_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                stream=True,
+            )
+        )
+
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
+
+
 async def run_main(host: str, port: int):
-    # client to test remote impl of agentic system
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
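This hunk hoists the per-prompt turn loop into `_run_agent` so plain chat (`run_main`) and the new RAG flow (`run_rag`, below) share it; the next hunk deletes the now-duplicated setup from `run_main`. A usage sketch, assuming the script's names are importable and an `asyncio.run` wrapper like the one `main` presumably uses:

import asyncio

async def demo(host: str = "localhost", port: int = 5000):
    # Point the client at a running agentic-system server.
    api = AgenticSystemClient(f"http://{host}:{port}")
    # No tools: a plain chat turn through the shared helper.
    await _run_agent(api, tool_definitions=[], user_prompts=["Who are you?"])

if __name__ == "__main__":
    asyncio.run(demo())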
@@ -118,24 +155,6 @@ async def run_main(host: str, port: int):
         ),
     ]
 
-    agent_config = AgentConfig(
-        model="Meta-Llama3.1-8B-Instruct",
-        instructions="You are a helpful assistant",
-        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
-        tools=tool_definitions,
-        tool_choice=ToolChoice.auto,
-        tool_prompt_format=ToolPromptFormat.function_tag,
-    )
-
-    create_response = await api.create_agentic_system(agent_config)
-    print(create_response)
-
-    session_response = await api.create_agentic_system_session(
-        agent_id=create_response.agent_id,
-        session_name="test_session",
-    )
-    print(session_response)
-
     user_prompts = [
         "Who are you?",
         "what is the 100th prime number?",
@@ -143,22 +162,32 @@ async def run_main(host: str, port: int):
         "Write code to check if a number is prime. Use that to check if 7 is prime",
         "What is the boiling point of polyjuicepotion ?",
     ]
-    for content in user_prompts:
-        cprint(f"User> {content}", color="blue")
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
-                agent_id=create_response.agent_id,
-                session_id=session_response.session_id,
-                messages=[
-                    UserMessage(content=content),
-                ],
-                stream=True,
-            )
-        )
-
-        async for event, log in EventLogger().log(iterator):
-            if log is not None:
-                log.print()
+    await _run_agent(api, tool_definitions, user_prompts)
+
+
+async def run_rag(host: str, port: int):
+    api = AgenticSystemClient(f"http://{host}:{port}")
+
+    # NOTE: for this, I ran `llama_toolchain.memory.client` first which
+    # populated the memory banks with torchtune docs. Then grabbed the bank_id
+    # from the output of that run.
+    tool_definitions = [
+        MemoryToolDefinition(
+            max_tokens_in_context=2048,
+            memory_bank_configs=[
+                AgenticSystemVectorMemoryBankConfig(
+                    bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
+                )
+            ],
+        ),
+    ]
+
+    user_prompts = [
+        "How do I use Lora?",
+        "Tell me briefly about llama3 and torchtune",
+    ]
+
+    await _run_agent(api, tool_definitions, user_prompts)
 
 
 def main(host: str, port: int):
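`run_rag` hard-codes a `bank_id` captured from a prior run of `llama_toolchain.memory.client`, per the NOTE in the diff. A sketch of lifting that id out of the source so the demo survives re-ingestion; the `LLAMA_BANK_ID` environment variable is an illustrative convention, not part of the repo:

import os

def rag_tool_definitions(bank_id: str | None = None) -> list:
    # Fall back to the bank id captured from the memory client run in the diff.
    bank_id = bank_id or os.environ.get(
        "LLAMA_BANK_ID", "970b8790-268e-4fd3-a9b1-d0e597e975ed"
    )
    return [
        MemoryToolDefinition(
            max_tokens_in_context=2048,
            memory_bank_configs=[
                AgenticSystemVectorMemoryBankConfig(bank_id=bank_id),
            ],
        ),
    ]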
File 2 of 3: the ChatAgent implementation (class ChatAgent).

@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import asyncio
 import copy
 import uuid
 from datetime import datetime
@@ -304,7 +304,9 @@ class ChatAgent(ShieldRunnerMixin):
 
         # TODO: find older context from the session and either replace it
         # or append with a sliding window. this is really a very simplistic implementation
-        rag_context, bank_ids = await self._retrieve_context(input_messages)
+        rag_context, bank_ids = await self._retrieve_context(
+            session, input_messages, attachments
+        )
 
         step_id = str(uuid.uuid4())
         yield AgenticSystemTurnResponseStreamChunk(
@@ -313,20 +315,24 @@ class ChatAgent(ShieldRunnerMixin):
                     step_type=StepType.memory_retrieval.value,
                     step_id=step_id,
                     step_details=MemoryRetrievalStep(
+                        turn_id=turn_id,
+                        step_id=step_id,
                         memory_bank_ids=bank_ids,
-                        inserted_context=rag_context,
+                        inserted_context=rag_context or "",
                     ),
                 )
             )
         )
 
         if rag_context:
-            system_message = next(m for m in input_messages if m.role == "system")
+            system_message = next(
+                (m for m in input_messages if m.role == "system"), None
+            )
             if system_message:
                 system_message.content = system_message.content + "\n" + rag_context
             else:
                 input_messages = [
-                    Message(role="system", content=rag_context)
+                    SystemMessage(content=rag_context)
                 ] + input_messages
 
         elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
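Two fixes here: the bare `next(...)` would raise `StopIteration` when no system message exists, so it gains a `None` default, and the prepended message now uses the proper `SystemMessage` datatype. The injection rule in isolation (a sketch; `SystemMessage` and attribute access stand in for the repo's datatypes):

def inject_rag_context(input_messages: list, rag_context: str) -> list:
    # Find an existing system message, or None instead of raising StopIteration.
    system_message = next((m for m in input_messages if m.role == "system"), None)
    if system_message is not None:
        # Append retrieved context to the existing system prompt.
        system_message.content = system_message.content + "\n" + rag_context
        return input_messages
    # Otherwise prepend a fresh system message carrying only the context.
    return [SystemMessage(content=rag_context)] + input_messages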
@@ -644,7 +650,7 @@ class ChatAgent(ShieldRunnerMixin):
             *sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
         )
         if not chunks:
-            return None
+            return None, bank_ids
 
         tokens = 0
         picked = []
@@ -656,13 +662,13 @@ class ChatAgent(ShieldRunnerMixin):
                     "red",
                 )
                 break
-            picked.append(c)
+            picked.append(c.content)
 
         return [
             "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
             *picked,
             "\n=== END-RETRIEVED-CONTEXT ===\n",
-        ]
+        ], bank_ids
 
     def _get_tools(self) -> List[ToolDefinition]:
         ret = []
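Together these two hunks make `_retrieve_context` return a `(context, bank_ids)` pair on every path and accumulate chunk contents rather than chunk objects. A self-contained sketch of that greedy, score-ordered, token-budgeted packing; `Chunk` and `token_count` are stand-ins for the repo's internals:

from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class Chunk:
    content: str

def token_count(text: str) -> int:
    return len(text.split())  # stand-in for the real tokenizer accounting

def pick_context(
    chunks: List[Chunk], scores: List[float], bank_ids: List[str], budget: int
) -> Tuple[Optional[List[str]], List[str]]:
    # Highest-scoring chunks first, as in the sorted(zip(...)) line above.
    ordered = [c for c, _ in sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)]
    if not ordered:
        return None, bank_ids  # every branch now also reports the banks queried

    tokens, picked = 0, []
    for c in ordered:
        tokens += token_count(c.content)
        if tokens > budget:
            break  # budget exhausted, as in the red-warning branch in the diff
        picked.append(c.content)  # content strings, not Chunk objects (the fix above)

    return [
        "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
        *picked,
        "\n=== END-RETRIEVED-CONTEXT ===\n",
    ], bank_ids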
File 3 of 3: the prepare_messages helper.

@@ -1,4 +1,8 @@
-import textwrap
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_toolchain.inference.api import *  # noqa: F403
@@ -41,8 +45,21 @@ def prepare_messages(request: ChatCompletionRequest) -> List[Message]:
     sys_content += default_template.render()
 
     if existing_system_message:
+        # TODO: this fn is needed in many places
+        def _process(c):
+            if isinstance(c, str):
+                return c
+            else:
+                return "<media>"
+
         sys_content += "\n"
-        sys_content += existing_system_message.content
+        if isinstance(existing_system_message.content, str):
+            sys_content += _process(existing_system_message.content)
+        elif isinstance(existing_system_message.content, list):
+            sys_content += "\n".join(
+                [_process(c) for c in existing_system_message.content]
+            )
 
     messages.append(SystemMessage(content=sys_content))
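This hunk makes `prepare_messages` tolerate system-message content that is either a plain string or a list mixing text with media items, collapsing anything non-string to a `<media>` placeholder before concatenation. The normalization in isolation (a sketch; the real list items are the repo's datatypes, and `image_media` below is a hypothetical non-string item):

def flatten_content(content) -> str:
    def _process(c):
        # Text passes through; images or other attachments become a marker.
        return c if isinstance(c, str) else "<media>"

    if isinstance(content, str):
        return _process(content)
    if isinstance(content, list):
        return "\n".join(_process(c) for c in content)
    return ""

# e.g. flatten_content(["look at this", image_media]) -> "look at this\n<media>"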