Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 20:14:13 +00:00
basic RAG seems to work
commit 58e2feceb0 (parent 830252257b)
3 changed files with 96 additions and 44 deletions
@@ -82,6 +82,10 @@ class AgenticSystemClient(AgenticSystem):
                 if line.startswith("data:"):
                     data = line[len("data: ") :]
                     try:
+                        if "error" in data:
+                            cprint(data, "red")
+                            continue
+
                         yield AgenticSystemTurnResponseStreamChunk(
                             **json.loads(data)
                         )
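For context on the hunk above: the `line` values come from a server-sent-events style HTTP stream, where chunks arrive as `data: <json>` lines; the new `+` lines short-circuit on error payloads instead of trying to validate them as stream chunks. A minimal sketch of the kind of loop that feeds this parser, assuming an httpx AsyncClient and a hypothetical endpoint path (neither the path nor the payload shape is shown in this diff):

```python
# Sketch only: the endpoint path and payload shape are assumptions, not from this diff.
import httpx


async def iter_sse_data(base_url: str, payload: dict):
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            f"{base_url}/agentic_system/turn/create",  # hypothetical path
            json=payload,
            timeout=None,
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data:"):
                    # same prefix-stripping as the hunk above; the caller
                    # JSON-decodes the remainder into a stream chunk
                    yield line[len("data: ") :]
```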
@@ -90,8 +94,41 @@ class AgenticSystemClient(AgenticSystem):
                         print(f"Error with parsing or validation: {e}")
 
 
+async def _run_agent(api, tool_definitions, user_prompts):
+    agent_config = AgentConfig(
+        model="Meta-Llama3.1-8B-Instruct",
+        instructions="You are a helpful assistant",
+        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
+        tools=tool_definitions,
+        tool_choice=ToolChoice.auto,
+        tool_prompt_format=ToolPromptFormat.function_tag,
+    )
+
+    create_response = await api.create_agentic_system(agent_config)
+    session_response = await api.create_agentic_system_session(
+        agent_id=create_response.agent_id,
+        session_name="test_session",
+    )
+
+    for content in user_prompts:
+        cprint(f"User> {content}", color="blue")
+        iterator = api.create_agentic_system_turn(
+            AgenticSystemTurnCreateRequest(
+                agent_id=create_response.agent_id,
+                session_id=session_response.session_id,
+                messages=[
+                    UserMessage(content=content),
+                ],
+                stream=True,
+            )
+        )
+
+        async for event, log in EventLogger().log(iterator):
+            if log is not None:
+                log.print()
+
+
 async def run_main(host: str, port: int):
     # client to test remote impl of agentic system
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
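The hunk above factors the previously duplicated flow into a reusable `_run_agent` coroutine: build an `AgentConfig`, create the agent and a session, then stream each user turn through `EventLogger`. A minimal sketch of exercising the helper on its own, assuming a llama-stack server is already listening on localhost:5000 (host and port are assumptions):

```python
import asyncio


async def smoke_test():
    # assumes a running server; no tools, a single prompt
    api = AgenticSystemClient("http://localhost:5000")
    await _run_agent(api, tool_definitions=[], user_prompts=["Who are you?"])


asyncio.run(smoke_test())
```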
@@ -118,24 +155,6 @@ async def run_main(host: str, port: int):
         ),
     ]
 
-    agent_config = AgentConfig(
-        model="Meta-Llama3.1-8B-Instruct",
-        instructions="You are a helpful assistant",
-        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
-        tools=tool_definitions,
-        tool_choice=ToolChoice.auto,
-        tool_prompt_format=ToolPromptFormat.function_tag,
-    )
-
-    create_response = await api.create_agentic_system(agent_config)
-    print(create_response)
-
-    session_response = await api.create_agentic_system_session(
-        agent_id=create_response.agent_id,
-        session_name="test_session",
-    )
-    print(session_response)
-
     user_prompts = [
         "Who are you?",
         "what is the 100th prime number?",
@@ -143,22 +162,32 @@ async def run_main(host: str, port: int):
         "Write code to check if a number is prime. Use that to check if 7 is prime",
         "What is the boiling point of polyjuicepotion ?",
     ]
-    for content in user_prompts:
-        cprint(f"User> {content}", color="blue")
-        iterator = api.create_agentic_system_turn(
-            AgenticSystemTurnCreateRequest(
-                agent_id=create_response.agent_id,
-                session_id=session_response.session_id,
-                messages=[
-                    UserMessage(content=content),
-                ],
-                stream=True,
-            )
-        )
+    await _run_agent(api, tool_definitions, user_prompts)
 
-        async for event, log in EventLogger().log(iterator):
-            if log is not None:
-                log.print()
+
+async def run_rag(host: str, port: int):
+    api = AgenticSystemClient(f"http://{host}:{port}")
+
+    # NOTE: for this, I ran `llama_toolchain.memory.client` first which
+    # populated the memory banks with torchtune docs. Then grabbed the bank_id
+    # from the output of that run.
+    tool_definitions = [
+        MemoryToolDefinition(
+            max_tokens_in_context=2048,
+            memory_bank_configs=[
+                AgenticSystemVectorMemoryBankConfig(
+                    bank_id="970b8790-268e-4fd3-a9b1-d0e597e975ed",
+                )
+            ],
+        ),
+    ]
+
+    user_prompts = [
+        "How do I use Lora?",
+        "Tell me briefly about llama3 and torchtune",
+    ]
+
+    await _run_agent(api, tool_definitions, user_prompts)
 
 
 def main(host: str, port: int):
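The diff cuts off inside `main`, so the dispatch between `run_main` and the new `run_rag` is not visible here. One plausible wiring, sketched under the assumption of a `fire`-based CLI entry point (the `rag` flag is hypothetical):

```python
import asyncio

import fire


def main(host: str, port: int, rag: bool = False):
    # hypothetical flag: choose the new RAG flow or the tool-calling flow
    fn = run_rag if rag else run_main
    asyncio.run(fn(host, port))


if __name__ == "__main__":
    fire.Fire(main)
```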