Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
This is yet another of those large PRs (hopefully we will have fewer and fewer of them as things mature). This one introduces substantial improvements and some simplifications to the stack. The most important bits:

* The Agents reference implementation now supports session / turn persistence. The default implementation uses SQLite, but Redis is also supported.
* We have re-architected the structure of the Stack APIs to allow for more flexible routing. The motivating use cases are:
  - routing model A to ollama and model B to a remote provider like Together
  - routing shield A to a local impl while shield B goes to a remote provider like Bedrock
  - routing a vector memory bank to Weaviate while routing a key-value memory bank to Redis
* Provider-specific parameters can now be passed from clients. A client can pass data using the `x_llamastack_provider_data` parameter, which is type-checked and provided to the Adapter implementations.
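As an illustration of that last point, here is a minimal sketch of how a client might attach provider-specific data to a request. The `x_llamastack_provider_data` name comes from the description above; the endpoint, the payload shape, and the keys inside the provider-data blob are assumptions, not confirmed API:

import httpx


async def create_agent_with_provider_data(base_url: str, agent_config: dict):
    # Hypothetical sketch: send provider-specific parameters alongside the
    # normal request payload; the server is expected to type-check this blob
    # and hand it to the matching Adapter implementation.
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{base_url}/agents/create",
            json={
                "agent_config": agent_config,
                # Whether this travels as a body field or a request header is
                # an assumption; the field name is taken from the text above.
                "x_llamastack_provider_data": {"together_api_key": "<your key>"},
            },
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        return response.json()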
218 lines
6.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
import os
from typing import AsyncGenerator

import fire
import httpx
from dotenv import load_dotenv

from pydantic import BaseModel
from termcolor import cprint

from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.distribution.datatypes import RemoteProviderConfig

from .agents import *  # noqa: F403
from .event_logger import EventLogger


load_dotenv()
async def get_client_impl(config: RemoteProviderConfig, _deps):
    return AgentsClient(config.url)


def encodable_dict(d: BaseModel):
    # Round-trip through JSON so nested pydantic models become plain,
    # JSON-encodable dicts.
    return json.loads(d.json())
class AgentsClient(Agents):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def create_agent(self, agent_config: AgentConfig) -> AgentCreateResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/agents/create",
                json={
                    "agent_config": encodable_dict(agent_config),
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return AgentCreateResponse(**response.json())

    async def create_agent_session(
        self,
        agent_id: str,
        session_name: str,
    ) -> AgentSessionCreateResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/agents/session/create",
                json={
                    "agent_id": agent_id,
                    "session_name": session_name,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return AgentSessionCreateResponse(**response.json())

    async def create_agent_turn(
        self,
        request: AgentTurnCreateRequest,
    ) -> AsyncGenerator:
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                f"{self.base_url}/agents/turn/create",
                json=encodable_dict(request),
                headers={"Content-Type": "application/json"},
                timeout=20,
            ) as response:
                # The server streams server-sent events; each payload line is
                # prefixed with "data: ".
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data = line[len("data: ") :]
                        try:
                            jdata = json.loads(data)
                            if "error" in jdata:
                                cprint(data, "red")
                                continue

                            yield AgentTurnResponseStreamChunk(**jdata)
                        except Exception as e:
                            print(data)
                            print(f"Error with parsing or validation: {e}")
async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
    agent_config = AgentConfig(
        model="Meta-Llama3.1-8B-Instruct",
        instructions="You are a helpful assistant",
        sampling_params=SamplingParams(temperature=1.0, top_p=0.9),
        tools=tool_definitions,
        tool_choice=ToolChoice.auto,
        tool_prompt_format=ToolPromptFormat.function_tag,
        enable_session_persistence=False,
    )

    create_response = await api.create_agent(agent_config)
    session_response = await api.create_agent_session(
        agent_id=create_response.agent_id,
        session_name="test_session",
    )

    for content in user_prompts:
        cprint(f"User> {content}", color="white", attrs=["bold"])
        iterator = api.create_agent_turn(
            AgentTurnCreateRequest(
                agent_id=create_response.agent_id,
                session_id=session_response.session_id,
                messages=[
                    UserMessage(content=content),
                ],
                attachments=attachments,
                stream=True,
            )
        )

        async for event, log in EventLogger().log(iterator):
            if log is not None:
                log.print()
async def run_main(host: str, port: int):
    api = AgentsClient(f"http://{host}:{port}")

    tool_definitions = [
        SearchToolDefinition(
            engine=SearchEngineType.brave,
            api_key=os.getenv("BRAVE_SEARCH_API_KEY"),
        ),
        WolframAlphaToolDefinition(api_key=os.getenv("WOLFRAM_ALPHA_API_KEY")),
        CodeInterpreterToolDefinition(),
    ]
    tool_definitions += [
        FunctionCallToolDefinition(
            function_name="get_boiling_point",
            description="Get the boiling point of an imaginary liquid (e.g. polyjuice)",
            parameters={
                "liquid_name": ToolParamDefinition(
                    param_type="str",
                    description="The name of the liquid",
                    required=True,
                ),
                "celcius": ToolParamDefinition(
                    param_type="str",
                    description="Whether to return the boiling point in Celsius",
                    required=False,
                ),
            },
        ),
    ]

    user_prompts = [
        "Who are you?",
        "What is the 100th prime number?",
        "Search the web for who was the 44th President of the USA.",
        "Write code to check if a number is prime. Use that to check if 7 is prime",
        "What is the boiling point of polyjuicepotion?",
    ]
    await _run_agent(api, tool_definitions, user_prompts)
async def run_rag(host: str, port: int):
    api = AgentsClient(f"http://{host}:{port}")

    urls = [
        "memory_optimizations.rst",
        "chat.rst",
        "llama3.rst",
        "datasets.rst",
        "qat_finetune.rst",
        "lora_finetune.rst",
    ]
    attachments = [
        Attachment(
            content=URL(
                uri=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}"
            ),
            mime_type="text/plain",
        )
        for url in urls
    ]

    # Alternatively, you can pre-populate the memory bank with documents, for
    # example using `llama_stack.memory.client`. Then you can grab the bank_id
    # from the output of that run and reference it here, as in the sketch below.
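    # A hedged sketch of that flow (the bank config class name and its fields
    # are assumptions, not confirmed by this file):
    #
    #     tool_definitions = [
    #         MemoryToolDefinition(
    #             max_tokens_in_context=2048,
    #             memory_bank_configs=[
    #                 AgentVectorMemoryBankConfig(bank_id="<bank_id from that run>"),
    #             ],
    #         ),
    #     ]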
    tool_definitions = [
        MemoryToolDefinition(
            max_tokens_in_context=2048,
            memory_bank_configs=[],
        ),
    ]

    user_prompts = [
        "How do I use Lora?",
        "Tell me briefly about llama3 and torchtune",
    ]

    await _run_agent(api, tool_definitions, user_prompts, attachments)
def main(host: str, port: int, rag: bool = False):
    fn = run_rag if rag else run_main
    asyncio.run(fn(host, port))


if __name__ == "__main__":
    fire.Fire(main)
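# Example invocation, assuming a Stack server is already listening on
# localhost:5000 (the module path below is illustrative, not confirmed):
#
#   python -m llama_stack.apis.agents.client localhost 5000
#   python -m llama_stack.apis.agents.client localhost 5000 --rag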