Update server.py to clean up the streaming vs. non-streaming code paths

Also make sure agent turn create is correctly marked
This commit is contained in:
Ashwin Bharambe 2024-10-08 14:28:50 -07:00 committed by Ashwin Bharambe
parent 640c5c54f7
commit 7f1160296c
13 changed files with 115 additions and 128 deletions

View file

@ -7,7 +7,7 @@
import asyncio
import json
import os
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
import fire
import httpx
@ -67,9 +67,17 @@ class AgentsClient(Agents):
response.raise_for_status()
return AgentSessionCreateResponse(**response.json())
async def create_agent_turn(
# Dispatch an agent-turn request to the streaming or the non-streaming helper.
# This is a plain `def`, not `async def`: the selected helper is called and
# its return value is handed back to the caller without being awaited here.
def create_agent_turn(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
# A truthy `request.stream` selects the streaming helper.
if request.stream:
return self._stream_agent_turn(request)
else:
# NOTE(review): `_nonstream_agent_turn` is an `async def` whose visible
# body contains no `yield`, so this branch appears to return a coroutine
# rather than the annotated AsyncGenerator -- confirm callers handle both.
return self._nonstream_agent_turn(request)
async def _stream_agent_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
async with httpx.AsyncClient() as client:
async with client.stream(
@ -93,6 +101,9 @@ class AgentsClient(Agents):
print(data)
print(f"Error with parsing or validation: {e}")
# Non-streaming counterpart of the agent-turn call; currently a stub.
# `request` is accepted but not used. Because this is an `async def`, the
# NotImplementedError is raised when the returned coroutine is awaited,
# not at the moment the method is called.
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
raise NotImplementedError("Non-streaming not implemented yet")
async def _run_agent(
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@ -132,8 +143,7 @@ async def _run_agent(
log.print()
async def run_llama_3_1(host: str, port: int):
model = "Llama3.1-8B-Instruct"
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
tool_definitions = [
@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
async def run_llama_3_2_rag(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
urls = [
@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
)
async def run_llama_3_2(host: str, port: int):
model = "Llama3.2-3B-Instruct"
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
api = AgentsClient(f"http://{host}:{port}")
# zero shot tools for llama3.2 text models
@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
)
def main(host: str, port: int, run_type: str):
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
assert run_type in [
"tools_llama_3_1",
"tools_llama_3_2",
@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
"tools_llama_3_2": run_llama_3_2,
"rag_llama_3_2": run_llama_3_2_rag,
}
asyncio.run(fn[run_type](host, port))
args = [host, port]
if model is not None:
args.append(model)
asyncio.run(fn[run_type](*args))
if __name__ == "__main__":