Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-08 19:10:56 +00:00)
Updates to server.py to clean up streaming vs non-streaming stuff
Also make sure agent turn create is correctly marked
commit 7f1160296c
parent 640c5c54f7
13 changed files with 115 additions and 128 deletions

@@ -411,8 +411,10 @@ class Agents(Protocol):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
     @webmethod(route="/agents/turn/create")
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
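The protocol change above drops `async` from `create_agent_turn` because, as the new comment says, the method hands back either an `AsyncGenerator` (streaming) or an awaitable response (non-streaming) without awaiting anything itself. Below is a minimal, self-contained sketch of that pattern; the names (`ExampleAPI`, `create_turn`, `_stream`, `_nonstream`) are hypothetical and not taken from the repository.

import asyncio
from typing import Any, AsyncGenerator, Coroutine, Union


class ExampleAPI:
    # Plain `def`: it only decides which object to hand back, it never awaits.
    def create_turn(self, stream: bool) -> Union[AsyncGenerator, Coroutine[Any, Any, str]]:
        if stream:
            return self._stream()      # async generator object
        return self._nonstream()       # coroutine object

    async def _stream(self) -> AsyncGenerator:
        for i in range(3):
            yield f"chunk-{i}"

    async def _nonstream(self) -> str:
        return "full response"


async def demo() -> None:
    api = ExampleAPI()
    async for chunk in api.create_turn(stream=True):   # iterate the generator
        print(chunk)
    print(await api.create_turn(stream=False))          # await the coroutine


asyncio.run(demo())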

@@ -7,7 +7,7 @@
 import asyncio
 import json
 import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional

 import fire
 import httpx

@@ -67,9 +67,17 @@ class AgentsClient(Agents):
         response.raise_for_status()
         return AgentSessionCreateResponse(**response.json())

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
+        if request.stream:
+            return self._stream_agent_turn(request)
+        else:
+            return self._nonstream_agent_turn(request)
+
+    async def _stream_agent_turn(
+        self, request: AgentTurnCreateRequest
+    ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
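Because the client method is now a plain `def`, what a caller does with the result depends on `request.stream`: the streaming path returns an async generator to iterate, the non-streaming path returns a coroutine to await. A hedged usage sketch follows; `client` and `request` stand in for an AgentsClient-style object and an AgentTurnCreateRequest-style object with a `stream` flag, and the exact request fields are not shown here.

async def consume_turn(client, request) -> None:
    if request.stream:
        # Streaming: the plain `def` returned an async generator of events.
        async for chunk in client.create_agent_turn(request):
            print(chunk)
    else:
        # Non-streaming: the plain `def` returned a coroutine to await.
        turn = await client.create_agent_turn(request)
        print(turn)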

@@ -93,6 +101,9 @@ class AgentsClient(Agents):
                             print(data)
                             print(f"Error with parsing or validation: {e}")

+    async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
+        raise NotImplementedError("Non-streaming not implemented yet")
+

 async def _run_agent(
     api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
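For context on the streaming half, this is roughly how an httpx-based helper can consume a server-sent-events response line by line, with the same kind of parse-error reporting seen above. It is a simplified sketch under assumptions, not the repository's implementation: the URL, payload shape, and the `data: ` prefix handling are placeholders.

import json

import httpx


async def stream_events(url: str, payload: dict):
    # Sketch: POST the request and yield each parsed SSE "data:" payload.
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload, timeout=30) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                try:
                    yield json.loads(line[len("data: "):])
                except json.JSONDecodeError as e:
                    print(f"Error with parsing or validation: {e}")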

@@ -132,8 +143,7 @@ async def _run_agent(
         log.print()


-async def run_llama_3_1(host: str, port: int):
-    model = "Llama3.1-8B-Instruct"
+async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [

@@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
     await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)


-async def run_llama_3_2_rag(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     urls = [

@@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
     )


-async def run_llama_3_2(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     # zero shot tools for llama3.2 text models

@@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
     )


-def main(host: str, port: int, run_type: str):
+def main(host: str, port: int, run_type: str, model: Optional[str] = None):
     assert run_type in [
         "tools_llama_3_1",
         "tools_llama_3_2",

@@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
         "tools_llama_3_2": run_llama_3_2,
         "rag_llama_3_2": run_llama_3_2_rag,
     }
-    asyncio.run(fn[run_type](host, port))
+    args = [host, port]
+    if model is not None:
+        args.append(model)
+    asyncio.run(fn[run_type](*args))


 if __name__ == "__main__":
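The `main()` change forwards the optional model only when one was actually provided, so each `run_*` helper keeps its own default. A small standalone sketch of that forwarding pattern (the function and model names here are placeholders):

from typing import Optional


def run_demo(host: str, port: int, model: str = "Llama3.2-3B-Instruct") -> str:
    # Stand-in for one of the run_* helpers, with its own model default.
    return f"{host}:{port} -> {model}"


def main(host: str, port: int, model: Optional[str] = None) -> str:
    args = [host, port]
    if model is not None:
        args.append(model)  # only override the helper's default when a model was given
    return run_demo(*args)


print(main("localhost", 5000))                            # uses the helper's default
print(main("localhost", 5000, "Llama3.1-70B-Instruct"))   # explicit override

Since the client imports fire, the extra parameter is presumably exposed as an optional CLI flag as well, though the actual entry point is outside these hunks.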

@@ -22,7 +22,7 @@ class MemoryBankType(Enum):

 class CommonDef(BaseModel):
     identifier: str
-    provider_id: str
+    provider_id: Optional[str] = None


 @json_schema_type

@@ -18,8 +18,8 @@ class ModelDef(BaseModel):
     llama_model: str = Field(
         description="Pointer to the core Llama family model",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this model"
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this model"
     )
     # For now, we are only supporting core llama models but as soon as finetuned
     # and other custom models (for example various quantizations) are allowed, there
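Making `provider_id` an `Optional[str]` with `default=None` means a model (or memory bank) definition can now be constructed without naming a provider instance up front. A minimal Pydantic sketch of the behavior, using a simplified stand-in class rather than the real `ModelDef`:

from typing import Optional

from pydantic import BaseModel, Field


class ModelDefSketch(BaseModel):
    # Simplified stand-in for ModelDef; only the fields relevant to the change.
    identifier: str
    llama_model: str = Field(description="Pointer to the core Llama family model")
    provider_id: Optional[str] = Field(
        default=None, description="The provider instance which serves this model"
    )


m = ModelDefSketch(identifier="my-model", llama_model="Llama3.1-8B-Instruct")
print(m.provider_id)  # None -- the provider can be filled in later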

@@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
     )
     print(response)

-    response = await client.run_shield(
-        shield_type="injection_shield",
-        messages=[message],
-    )
-    print(response)
-

 def main(host: str, port: int, image: str = None):
     asyncio.run(run_main(host, port, image))

@@ -23,12 +23,12 @@ class ShieldDef(BaseModel):
     identifier: str = Field(
         description="A unique identifier for the shield type",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this shield"
-    )
     type: str = Field(
         description="The type of shield this is; the value is one of the ShieldType enum"
     )
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this shield"
+    )
     params: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional parameters needed for this shield",
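The `ShieldDef` hunk applies the same `provider_id` relaxation and keeps `params` with a `default_factory`. A small sketch (a simplified stand-in for `ShieldDef`, with made-up identifier and type values) showing why `default_factory=dict` matters: each instance gets its own dict rather than one shared mutable default.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class ShieldDefSketch(BaseModel):
    # Simplified stand-in for ShieldDef.
    identifier: str
    type: str
    provider_id: Optional[str] = Field(default=None)
    params: Dict[str, Any] = Field(default_factory=dict)


a = ShieldDefSketch(identifier="shield_a", type="content_shield")
b = ShieldDefSketch(identifier="shield_b", type="content_shield")
a.params["threshold"] = 0.5
print(b.params)  # {} -- the two instances do not share a params dict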