mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
tmp
This commit is contained in:
parent
6b6feebc72
commit
7dae81cb68
3 changed files with 30 additions and 7 deletions
|
@ -344,15 +344,15 @@ class Agents(Protocol):
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(
|
@webmethod(
|
||||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/submit_tool_response_messages",
|
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/tool_responses",
|
||||||
method="POST",
|
method="POST",
|
||||||
)
|
)
|
||||||
async def submit_tool_response_messages(
|
async def submit_tool_responses(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
turn_id: str,
|
turn_id: str,
|
||||||
tool_response_messages: List[ToolResponseMessage],
|
tool_responses: Dict[str, ToolResponseMessage],
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(
|
@webmethod(
|
||||||
|
|
|
@ -31,6 +31,7 @@ from llama_stack.apis.agents import (
|
||||||
AgentTurnResponseStepStartPayload,
|
AgentTurnResponseStepStartPayload,
|
||||||
AgentTurnResponseStreamChunk,
|
AgentTurnResponseStreamChunk,
|
||||||
AgentTurnResponseTurnCompletePayload,
|
AgentTurnResponseTurnCompletePayload,
|
||||||
|
AgentTurnResponseTurnPendingPayload,
|
||||||
AgentTurnResponseTurnStartPayload,
|
AgentTurnResponseTurnStartPayload,
|
||||||
Attachment,
|
Attachment,
|
||||||
Document,
|
Document,
|
||||||
|
@ -62,7 +63,11 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
BuiltinTool,
|
||||||
|
ToolCall,
|
||||||
|
ToolParamDefinition,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
|
||||||
from llama_stack.providers.utils.telemetry import tracing
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
@ -222,6 +227,15 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
|
if output_message.tool_calls:
|
||||||
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseTurnPendingPayload(
|
||||||
|
turn=turn,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
payload=AgentTurnResponseTurnCompletePayload(
|
payload=AgentTurnResponseTurnCompletePayload(
|
||||||
|
@ -229,7 +243,8 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
yield chunk
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage
|
||||||
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
|
||||||
from llama_stack_client.types.tool_def_param import Parameter
|
from llama_stack_client.types.tool_def_param import Parameter
|
||||||
|
|
||||||
from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig
|
from llama_stack.apis.agents.agents import (
|
||||||
from llama_stack.apis.agents.agents import ToolChoice
|
AgentConfig as Server__AgentConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.agents.agents import (
|
||||||
|
ToolChoice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestClientTool(ClientTool):
|
class TestClientTool(ClientTool):
|
||||||
|
@ -314,6 +318,10 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
],
|
],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
from rich.pretty import pprint
|
||||||
|
|
||||||
|
for x in response:
|
||||||
|
pprint(x)
|
||||||
|
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue