[API Updates] Model / shield / memory-bank routing + agent persistence + support for private headers (#92)

This is yet another of those large PRs (hopefully we will have less and less of them as things mature fast). This one introduces substantial improvements and some simplifications to the stack.

Most important bits:

* Agents reference implementation now has support for session / turn persistence. The default implementation uses sqlite but there's also support for using Redis.

* We have re-architected the structure of the Stack APIs to allow for more flexible routing. The motivating use cases are:
  - routing model A to ollama and model B to a remote provider like Together
  - routing shield A to local impl while shield B to a remote provider like Bedrock
  - routing a vector memory bank to Weaviate while routing a keyvalue memory bank to Redis

* Support for provider specific parameters to be passed from the clients. A client can pass data using `x_llamastack_provider_data` parameter which can be type-checked and provided to the Adapter implementations.
This commit is contained in:
Ashwin Bharambe 2024-09-23 14:22:22 -07:00 committed by GitHub
parent 8bf8c07eb3
commit ec4fc800cc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
130 changed files with 9701 additions and 11227 deletions

View file

@ -25,10 +25,21 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
from .persistence import AgentPersistence
from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import SafeTool
def make_random_string(length: int = 8):
@ -40,23 +51,44 @@ def make_random_string(length: int = 8):
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
agent_id: str,
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
max_infer_iters: int = 10,
persistence_store: KVStore,
):
self.agent_id = agent_id
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
self.sessions = {}
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=self.tempdir)
else:
continue
builtin_tools.append(
SafeTool(
tool,
safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
self.tools_dict = {t.get_name(): t for t in builtin_tools}
ShieldRunnerMixin.__init__(
self,
@ -80,7 +112,6 @@ class ChatAgent(ShieldRunnerMixin):
msg.context = None
messages.append(msg)
# messages.extend(turn.input_messages)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
@ -94,43 +125,35 @@ class ChatAgent(ShieldRunnerMixin):
)
)
elif step.step_type == StepType.shield_call.value:
response = step.response
if response.is_violation:
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=response.violation_return_message,
content=violation.user_message,
stop_reason=StopReason.end_of_turn,
)
)
# print_dialog(messages)
return messages
def create_session(self, name: str) -> Session:
session_id = str(uuid.uuid4())
session = Session(
session_id=session_id,
session_name=name,
turns=[],
started_at=datetime.now(),
)
self.sessions[session_id] = session
return session
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert (
request.session_id in self.sessions
), f"Session {request.session_id} not found"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
session = self.sessions[request.session_id]
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if len(session.turns) == 0 and self.agent_config.instructions != "":
if len(turns) == 0 and self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(session.turns):
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
@ -148,7 +171,7 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
output_message = None
async for chunk in self.run(
session=session,
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
@ -187,7 +210,7 @@ class ChatAgent(ShieldRunnerMixin):
completed_at=datetime.now(),
steps=steps,
)
session.turns.append(turn)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -200,7 +223,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run(
self,
session: Session,
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@ -212,7 +235,7 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
@ -221,7 +244,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res
async for res in self._run(
session, turn_id, input_messages, attachments, sampling_params, stream
session_id, turn_id, input_messages, attachments, sampling_params, stream
):
if isinstance(res, bool):
return
@ -235,7 +258,7 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
@ -245,11 +268,12 @@ class ChatAgent(ShieldRunnerMixin):
yield final_response
async def run_shields_wrapper(
@tracing.span("run_shields")
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: List[Message],
shields: List[ShieldDefinition],
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
@ -266,7 +290,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
await self.run_shields(messages, shields)
await self.run_multiple_shields(messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
@ -276,7 +300,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=e.response,
violation=e.violation,
),
)
)
@ -295,12 +319,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
violation=None,
),
)
)
@ -308,7 +327,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _run(
self,
session: Session,
session_id: str,
turn_id: str,
input_messages: List[Message],
attachments: List[Attachment],
@ -332,9 +351,10 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context(
session, input_messages, attachments
)
with tracing.span("retrieve_rag_context"):
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -387,55 +407,57 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls = []
content = ""
stop_reason = None
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
with tracing.span("inference"):
async for chunk in self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
tools=self._get_tools(),
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=sampling_params,
):
event = chunk.event
if event.event_type == ChatCompletionResponseEventType.start:
continue
elif event.event_type == ChatCompletionResponseEventType.complete:
stop_reason = StopReason.end_of_turn
continue
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
delta = event.delta
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
tool_call_delta=delta,
)
)
)
)
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
elif isinstance(delta, str):
content += delta
if stream and event.stop_reason is None:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
)
)
)
)
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
else:
raise ValueError(f"Unexpected delta type {type(delta)}")
if event.stop_reason is not None:
stop_reason = event.stop_reason
if event.stop_reason is not None:
stop_reason = event.stop_reason
stop_reason = stop_reason or StopReason.out_of_tokens
message = CompletionMessage(
@ -461,7 +483,7 @@ class ChatAgent(ShieldRunnerMixin):
)
)
if n_iter >= self.max_infer_iters:
if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.")
yield message
break
@ -512,14 +534,15 @@ class ChatAgent(ShieldRunnerMixin):
)
)
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
with tracing.span("tool_execution"):
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
)
assert (
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
@ -550,12 +573,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=ShieldResponse(
# TODO: fix this, give each shield a shield type method and
# fire one event for each shield run
shield_type=BuiltinShield.llama_guard,
is_violation=False,
),
violation=None,
),
)
)
@ -569,7 +587,7 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=str(uuid.uuid4()),
turn_id=turn_id,
response=e.response,
violation=e.violation,
),
)
)
@ -594,17 +612,25 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
if session.memory_bank is None:
session.memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session.session_id}",
async def _ensure_memory_bank(self, session_id: str) -> str:
session_info = await self.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None:
memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session_id}",
config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
),
)
bank_id = memory_bank.bank_id
await self.storage.add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
return session.memory_bank
return bank_id
async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment]
@ -619,7 +645,6 @@ class ChatAgent(ShieldRunnerMixin):
else:
return True
print(f"{enabled_tools=}")
return AgentTool.memory.value in enabled_tools
def _memory_tool_definition(self) -> Optional[MemoryToolDefinition]:
@ -630,7 +655,7 @@ class ChatAgent(ShieldRunnerMixin):
return None
async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment]
self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
bank_ids = []
@ -639,8 +664,8 @@ class ChatAgent(ShieldRunnerMixin):
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments:
bank = await self._ensure_memory_bank(session)
bank_ids.append(bank.bank_id)
bank_id = await self._ensure_memory_bank(session_id)
bank_ids.append(bank_id)
documents = [
MemoryBankDocument(
@ -651,9 +676,12 @@ class ChatAgent(ShieldRunnerMixin):
)
for a in attachments
]
await self.memory_api.insert_documents(bank.bank_id, documents)
elif session.memory_bank:
bank_ids.append(session.memory_bank.bank_id)
with tracing.span("insert_documents"):
await self.memory_api.insert_documents(bank_id, documents)
else:
session_info = await self.storage.get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids:
# this can happen if the per-session memory bank is not yet populated