migrate agentic system

Xi Yan 2024-09-11 13:57:39 -07:00
parent 4b34f741d0
commit 6049aada71
5 changed files with 39 additions and 7 deletions

@@ -416,7 +416,16 @@ class AgenticSystem(Protocol):
     @webmethod(route="/agentic_system/turn/create")
     async def create_agentic_system_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        agent_id: str,
+        session_id: str,
+        messages: List[
+            Union[
+                UserMessage,
+                ToolResponseMessage,
+            ]
+        ],
+        attachments: Optional[List[Attachment]] = None,
+        stream: Optional[bool] = False,
     ) -> AgenticSystemTurnResponseStreamChunk: ...

     @webmethod(route="/agentic_system/turn/get")
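Note: for illustration, a minimal sketch of calling the flattened turn-creation signature above. The import paths, ids, and message content are assumptions, not part of this commit; implementations of the protocol yield stream chunks, so the call is iterated directly.

# Sketch of a caller after this change: turn parameters are passed directly
# instead of being wrapped in a request object.
from llama_models.llama3.api.datatypes import UserMessage  # import path assumed
from llama_toolchain.agentic_system.api import AgenticSystem  # import path assumed

async def send_turn(api: AgenticSystem, agent_id: str, session_id: str) -> None:
    async for chunk in api.create_agentic_system_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[UserMessage(content="What is 2 + 2?")],
        stream=True,
    ):
        print(chunk)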

@@ -73,9 +73,7 @@ class AgenticSystemClient(AgenticSystem):
             async with client.stream(
                 "POST",
                 f"{self.base_url}/agentic_system/turn/create",
-                json={
-                    "request": encodable_dict(request),
-                },
+                json=encodable_dict(request),
                 headers={"Content-Type": "application/json"},
                 timeout=20,
             ) as response:
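Note: the effect of the client change above on the wire format, sketched with made-up field values; the turn fields are now sent at the top level instead of under a "request" key.

# Illustrative request bodies only; values and message serialization are approximate.
old_body = {
    "request": {
        "agent_id": "agent-123",
        "session_id": "session-456",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }
}
new_body = {
    "agent_id": "agent-123",
    "session_id": "session-456",
    "messages": [{"role": "user", "content": "hello"}],
    "stream": True,
}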
@@ -134,7 +132,7 @@ async def run_main(host: str, port: int):
     api = AgenticSystemClient(f"http://{host}:{port}")

     tool_definitions = [
-        SearchToolDefinition(engine=SearchEngineType.bing),
+        SearchToolDefinition(engine=SearchEngineType.brave),
         WolframAlphaToolDefinition(),
         CodeInterpreterToolDefinition(),
     ]

@@ -400,7 +400,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_calls = []
         content = ""
         stop_reason = None
-        async for chunk in self.inference_api.chat_completion(req):
+        async for chunk in self.inference_api.chat_completion_impl(req):
             event = chunk.event
             if event.event_type == ChatCompletionResponseEventType.start:
                 continue
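Note: a hedged sketch of how the renamed call above is driven end to end; the request construction, model id, and import paths are assumptions, not taken from this commit. Only chat_completion_impl and the chunk.event access pattern come from the diff.

# Assumed imports and field values; the event schema is taken from the hunk above.
from llama_models.llama3.api.datatypes import UserMessage  # import path assumed
from llama_toolchain.inference.api import ChatCompletionRequest  # import path assumed

async def stream_once(inference_api) -> None:
    req = ChatCompletionRequest(
        model="Meta-Llama3.1-8B-Instruct",  # illustrative model id
        messages=[UserMessage(content="Say hello.")],
        stream=True,
    )
    # chat_completion_impl takes the request object directly, unlike the
    # flattened public chat_completion endpoint.
    async for chunk in inference_api.chat_completion_impl(req):
        event = chunk.event
        print(event.event_type)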

@@ -114,8 +114,26 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
     async def create_agentic_system_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        agent_id: str,
+        session_id: str,
+        messages: List[
+            Union[
+                UserMessage,
+                ToolResponseMessage,
+            ]
+        ],
+        attachments: Optional[List[Attachment]] = None,
+        stream: Optional[bool] = False,
     ) -> AsyncGenerator:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        request = AgenticSystemTurnCreateRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=messages,
+            attachments=attachments,
+            stream=stream,
+        )
+
         agent_id = request.agent_id
         assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
         agent = AGENT_INSTANCES_BY_ID[agent_id]

@@ -77,6 +77,13 @@ class MetaReferenceInferenceImpl(Inference):
             logprobs=logprobs,
         )
         return self._chat_completion(request)
+
+    async def chat_completion_impl(
+        self, request: ChatCompletionRequest
+    ) -> AsyncIterator[
+        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+    ]:
         messages = prepare_messages(request)
         model = resolve_model(request.model)
         if model is None: