Add timeout and retries for HTTP requests in AgenticSystemClient

This commit is contained in:
Mandlin Sarah 2024-09-03 03:20:07 -07:00
parent 70d557f793
commit b5d958631e

View file

@ -12,6 +12,7 @@ from typing import AsyncGenerator
import fire import fire
import httpx import httpx
from httpx import Timeout, Retry
from llama_models.llama3.api.datatypes import ( from llama_models.llama3.api.datatypes import (
BuiltinTool, BuiltinTool,
@ -47,7 +48,7 @@ class AgenticSystemClient(AgenticSystem):
async def create_agentic_system( async def create_agentic_system(
self, request: AgenticSystemCreateRequest self, request: AgenticSystemCreateRequest
) -> AgenticSystemCreateResponse: ) -> AgenticSystemCreateResponse:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient(timeout=Timeout(10.0), retries=Retry(3)) as client:
response = await client.post( response = await client.post(
f"{self.base_url}/agentic_system/create", f"{self.base_url}/agentic_system/create",
data=request.json(), data=request.json(),
@ -60,7 +61,7 @@ class AgenticSystemClient(AgenticSystem):
self, self,
request: AgenticSystemSessionCreateRequest, request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse: ) -> AgenticSystemSessionCreateResponse:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient(timeout=Timeout(10.0), retries=Retry(3)) as client:
response = await client.post( response = await client.post(
f"{self.base_url}/agentic_system/session/create", f"{self.base_url}/agentic_system/session/create",
data=request.json(), data=request.json(),
@ -73,13 +74,12 @@ class AgenticSystemClient(AgenticSystem):
self, self,
request: AgenticSystemTurnCreateRequest, request: AgenticSystemTurnCreateRequest,
) -> AsyncGenerator: ) -> AsyncGenerator:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient(timeout=Timeout(20.0), retries=Retry(3)) as client:
async with client.stream( async with client.stream(
"POST", "POST",
f"{self.base_url}/agentic_system/turn/create", f"{self.base_url}/agentic_system/turn/create",
data=request.json(), data=request.json(),
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=20,
) as response: ) as response:
async for line in response.aiter_lines(): async for line in response.aiter_lines():
if line.startswith("data:"): if line.startswith("data:"):
@ -182,3 +182,4 @@ def main(host: str, port: int):
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(main) fire.Fire(main)