mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-09 11:20:58 +00:00
Updates to server.py to clean up streaming vs non-streaming stuff
Also make sure agent turn create is correctly marked
This commit is contained in:
parent
640c5c54f7
commit
7f1160296c
13 changed files with 115 additions and 128 deletions
|
|
@ -7,7 +7,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
|
@ -67,9 +67,17 @@ class AgentsClient(Agents):
|
|||
response.raise_for_status()
|
||||
return AgentSessionCreateResponse(**response.json())
|
||||
|
||||
async def create_agent_turn(
|
||||
def create_agent_turn(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
if request.stream:
|
||||
return self._stream_agent_turn(request)
|
||||
else:
|
||||
return self._nonstream_agent_turn(request)
|
||||
|
||||
async def _stream_agent_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream(
|
||||
|
|
@ -93,6 +101,9 @@ class AgentsClient(Agents):
|
|||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
|
||||
async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
|
||||
raise NotImplementedError("Non-streaming not implemented yet")
|
||||
|
||||
|
||||
async def _run_agent(
|
||||
api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
|
||||
|
|
@ -132,8 +143,7 @@ async def _run_agent(
|
|||
log.print()
|
||||
|
||||
|
||||
async def run_llama_3_1(host: str, port: int):
|
||||
model = "Llama3.1-8B-Instruct"
|
||||
async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
tool_definitions = [
|
||||
|
|
@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
|
|||
await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
|
||||
|
||||
|
||||
async def run_llama_3_2_rag(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
urls = [
|
||||
|
|
@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
|
|||
)
|
||||
|
||||
|
||||
async def run_llama_3_2(host: str, port: int):
|
||||
model = "Llama3.2-3B-Instruct"
|
||||
async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
|
||||
api = AgentsClient(f"http://{host}:{port}")
|
||||
|
||||
# zero shot tools for llama3.2 text models
|
||||
|
|
@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
|
|||
)
|
||||
|
||||
|
||||
def main(host: str, port: int, run_type: str):
|
||||
def main(host: str, port: int, run_type: str, model: Optional[str] = None):
|
||||
assert run_type in [
|
||||
"tools_llama_3_1",
|
||||
"tools_llama_3_2",
|
||||
|
|
@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
|
|||
"tools_llama_3_2": run_llama_3_2,
|
||||
"rag_llama_3_2": run_llama_3_2_rag,
|
||||
}
|
||||
asyncio.run(fn[run_type](host, port))
|
||||
args = [host, port]
|
||||
if model is not None:
|
||||
args.append(model)
|
||||
asyncio.run(fn[run_type](*args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue