Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

Commit f5620c09ad ("Rag Updates"), parent ea6d9ec937.
4 changed files with 48 additions and 40 deletions.
@@ -8,7 +8,7 @@ import asyncio
 import copy
 import uuid
 from datetime import datetime
-from typing import AsyncGenerator, List
+from typing import AsyncGenerator, List, Tuple

 from termcolor import cprint

@@ -53,6 +53,44 @@ class ChatAgent(ShieldRunnerMixin):
             output_shields=agent_config.output_shields,
         )

+    def turn_to_messages(self, turn: Turn) -> List[Message]:
+        messages = []
+
+        # We do not want to keep adding RAG context to the input messages
+        # May be this should be a parameter of the agentic instance
+        # that can define its behavior in a custom way
+        for m in turn.input_messages:
+            msg = m.copy()
+            if isinstance(msg, UserMessage):
+                msg.context = None
+            messages.append(msg)
+
+        # messages.extend(turn.input_messages)
+        for step in turn.steps:
+            if step.step_type == StepType.inference.value:
+                messages.append(step.model_response)
+            elif step.step_type == StepType.tool_execution.value:
+                for response in step.tool_responses:
+                    messages.append(
+                        ToolResponseMessage(
+                            call_id=response.call_id,
+                            tool_name=response.tool_name,
+                            content=response.content,
+                        )
+                    )
+            elif step.step_type == StepType.shield_call.value:
+                response = step.response
+                if response.is_violation:
+                    # TODO: Properly persist the
+                    # CompletionMessage itself in the ShieldResponse
+                    messages.append(
+                        CompletionMessage(
+                            content=response.violation_return_message,
+                            stop_reason=StopReason.end_of_turn,
+                        )
+                    )
+        # print_dialog(messages)
+        return messages
+
     def create_session(self, name: str) -> Session:
         session_id = str(uuid.uuid4())
         session = Session(
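The new turn_to_messages helper replays a persisted turn as a flat message list while blanking the context field on user messages, so RAG context injected during one turn is not re-fed to the model when session history is rebuilt. A minimal standalone sketch of that idea; the dataclasses below are hypothetical stand-ins for the real Turn/UserMessage models, not the llama-stack types:

```python
from dataclasses import dataclass, replace
from typing import List, Optional

@dataclass(frozen=True)
class UserMessage:
    content: str
    context: Optional[str] = None  # RAG context injected at inference time

@dataclass(frozen=True)
class Turn:
    input_messages: List[UserMessage]

def turn_to_messages(turn: Turn) -> List[UserMessage]:
    # Copy each message and drop its retrieved context, so replayed
    # history never accumulates stale retrievals.
    return [replace(m, context=None) for m in turn.input_messages]

turn = Turn([UserMessage("What is FAISS?", context="id:doc-1; content:...")])
assert turn_to_messages(turn)[0].context is None
```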
@@ -75,32 +113,7 @@ class ChatAgent(ShieldRunnerMixin):

         messages = []
         for i, turn in enumerate(session.turns):
-            # print(f"turn {i}")
-            # print_dialog(turn.input_messages)
-            messages.extend(turn.input_messages)
-            for step in turn.steps:
-                if step.step_type == StepType.inference.value:
-                    messages.append(step.model_response)
-                elif step.step_type == StepType.tool_execution.value:
-                    for response in step.tool_responses:
-                        messages.append(
-                            ToolResponseMessage(
-                                call_id=response.call_id,
-                                tool_name=response.tool_name,
-                                content=response.content,
-                            )
-                        )
-                elif step.step_type == StepType.shield_call.value:
-                    response = step.response
-                    if response.is_violation:
-                        # TODO: Properly persist the
-                        # CompletionMessage itself in the ShieldResponse
-                        messages.append(
-                            CompletionMessage(
-                                content=response.violation_return_message,
-                                stop_reason=StopReason.end_of_turn,
-                            )
-                        )
+            messages.extend(self.turn_to_messages(turn))

         messages.extend(request.messages)

@@ -325,15 +338,8 @@ class ChatAgent(ShieldRunnerMixin):
             )

         if rag_context:
-            system_message = next(
-                (m for m in input_messages if m.role == "system"), None
-            )
-            if system_message:
-                system_message.content = system_message.content + "\n" + rag_context
-            else:
-                input_messages = [
-                    SystemMessage(content=rag_context)
-                ] + input_messages
+            last_message = input_messages[-1]
+            last_message.context = "\n".join(rag_context)

         elif attachments and AgenticSystemTool.code_interpreter.value in enabled_tools:
             urls = [a.content for a in attachments if isinstance(a.content, URL)]
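Two changes meet in the second hunk: _retrieve_context now returns a (rag_context, bank_ids) tuple of plain strings, and the caller attaches the joined context to the last input message's context field instead of growing or prepending a system message. A hedged caller-side sketch with simplified stand-in types (the real method also takes the session, messages, and attachments):

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class UserMessage:
    content: str
    context: Optional[str] = None

def retrieve_context() -> Tuple[List[str], List[str]]:
    # Stand-in for ChatAgent._retrieve_context(); bank ids are kept
    # opaque here (the diff annotates them as List[int]).
    return ["retrieved line 1", "retrieved line 2"], ["bank-0"]

input_messages = [UserMessage("Summarize the attached report.")]
rag_context, bank_ids = retrieve_context()
if rag_context:
    # New behavior: hang retrievals off the last message rather than
    # splicing them into a system message.
    input_messages[-1].context = "\n".join(rag_context)
```

Pairing the context with the final user message keeps it out of persisted history, since turn_to_messages strips it again on replay.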
@@ -610,7 +616,7 @@ class ChatAgent(ShieldRunnerMixin):

     async def _retrieve_context(
         self, session: Session, messages: List[Message], attachments: List[Attachment]
-    ) -> Optional[InterleavedTextMedia]:
+    ) -> Tuple[List[str], List[int]]:  # (rag_context, bank_ids)
         bank_ids = []

         memory = self._memory_tool_definition()

@@ -671,10 +677,10 @@ class ChatAgent(ShieldRunnerMixin):
                     "red",
                 )
                 break
-            picked.append(c.content)
+            picked.append(f"id:{c.document_id}; content:{c.content}")

         return [
-            "The following context was retrieved from the memory bank:\n=== START-RETRIEVED-CONTEXT ===\n",
+            "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
             *picked,
             "\n=== END-RETRIEVED-CONTEXT ===\n",
         ], bank_ids
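On the retrieval side, every picked chunk is now rendered as id:<document_id>; content:<content>, so the model can cite which document a passage came from, and the framing text is reworded. A minimal sketch of how the context block gets assembled, ignoring token budgeting and scoring:

```python
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class Chunk:
    document_id: str
    content: str

def format_context(chunks: List[Chunk], bank_ids: List[str]) -> Tuple[List[str], List[str]]:
    # Prefix each chunk with its source document id for provenance.
    picked = [f"id:{c.document_id}; content:{c.content}" for c in chunks]
    return [
        "Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n",
        *picked,
        "\n=== END-RETRIEVED-CONTEXT ===\n",
    ], bank_ids

context, banks = format_context(
    [Chunk("doc-1", "FAISS is a library for similarity search.")], ["bank-0"]
)
```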
@@ -49,7 +49,7 @@ def available_distribution_specs() -> List[DistributionSpec]:
             Api.inference: providers[Api.inference]["meta-ollama"],
             Api.safety: providers[Api.safety]["meta-reference"],
             Api.agentic_system: providers[Api.agentic_system]["meta-reference"],
-            Api.memory: remote_spec(Api.memory),
+            Api.memory: providers[Api.memory]["meta-reference-faiss"],
         },
     ),
     DistributionSpec(
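This distribution spec also swaps its memory provider from a remote service to the in-process meta-reference-faiss implementation, so a local stack no longer needs a separate memory server. Roughly, the lookup amounts to the following (a hypothetical registry shape inferred from the diff, not the real llama-stack API):

```python
# Hypothetical provider registry shaped like the one the diff indexes into.
providers = {
    "memory": {
        "meta-reference-faiss": "inline FAISS-backed memory provider",
    },
}

# Before: Api.memory -> remote_spec(Api.memory)                        (separate service)
# After:  Api.memory -> providers[Api.memory]["meta-reference-faiss"]  (in-process)
memory_spec = providers["memory"]["meta-reference-faiss"]
```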
@@ -68,6 +68,7 @@ MemoryBankConfig = Annotated[
 class Chunk(BaseModel):
     content: InterleavedTextMedia
     token_count: int
+    document_id: str


 @json_schema_type
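The Chunk schema gains a required document_id, which is what makes the id:...; content:... rendering above possible. A quick pydantic sketch of the extended model, with content narrowed to str for brevity:

```python
from pydantic import BaseModel

class Chunk(BaseModel):
    content: str  # InterleavedTextMedia in the real schema
    token_count: int
    document_id: str  # new: records which document the chunk came from

chunk = Chunk(
    content="FAISS is a library for similarity search.",
    token_count=8,
    document_id="doc-1",
)
```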
@@ -89,6 +89,7 @@ class BankState:
                 self.chunk_by_index[indexlen + i] = Chunk(
                     content=chunk[0],
                     token_count=chunk[1],
+                    document_id=doc.document_id,
                 )
                 print(f"Adding chunk #{indexlen + i} tokens={chunk[1]}")
                 self.id_by_index[indexlen + i] = doc.document_id
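Finally, the FAISS bank stamps each stored chunk with its originating document id at write time, keeping chunk_by_index and id_by_index in step. A simplified sketch of that bookkeeping, omitting embedding and the FAISS index itself:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

@dataclass
class Document:
    document_id: str

@dataclass
class BankState:
    chunk_by_index: Dict[int, dict] = field(default_factory=dict)
    id_by_index: Dict[int, str] = field(default_factory=dict)

    def add_chunks(self, doc: Document, chunks: List[Tuple[str, int]]) -> None:
        indexlen = len(self.chunk_by_index)
        for i, (content, token_count) in enumerate(chunks):
            # The chunk itself now carries its provenance ...
            self.chunk_by_index[indexlen + i] = {
                "content": content,
                "token_count": token_count,
                "document_id": doc.document_id,
            }
            # ... alongside the pre-existing index -> document id map.
            self.id_by_index[indexlen + i] = doc.document_id
```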