Updates to server.py to clean up streaming vs non-streaming stuff

Also make sure agent turn create is correctly marked
Ashwin Bharambe 2024-10-08 14:28:50 -07:00 committed by Ashwin Bharambe
parent 640c5c54f7
commit 7f1160296c
13 changed files with 115 additions and 128 deletions

View file

@@ -411,8 +411,10 @@ class Agents(Protocol):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
     @webmethod(route="/agents/turn/create")
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
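The new comment is the crux of the change: when a `stream` flag flips the return type between an `AsyncGenerator` and an `AgentTurnCreateResponse`, the method cannot be declared `async def`, because callers must either iterate the result or await it. A minimal, self-contained sketch of the pattern (illustrative names only, not the actual llama-stack code):

import asyncio
from typing import AsyncGenerator


class ExampleAPI:
    # Plain `def`: callers get an async generator to iterate when
    # stream=True, or a coroutine to await when stream=False.
    def create_turn(self, prompt: str, stream: bool = False):
        if stream:
            return self._stream_turn(prompt)
        return self._nonstream_turn(prompt)

    async def _stream_turn(self, prompt: str) -> AsyncGenerator[str, None]:
        for token in prompt.split():
            yield token

    async def _nonstream_turn(self, prompt: str) -> str:
        return prompt.upper()


async def demo() -> None:
    api = ExampleAPI()
    async for token in api.create_turn("hello streaming world", stream=True):
        print(token)
    print(await api.create_turn("hello", stream=False))


asyncio.run(demo())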

View file

@@ -7,7 +7,7 @@
 import asyncio
 import json
 import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional

 import fire
 import httpx
@@ -67,9 +67,17 @@ class AgentsClient(Agents):
         response.raise_for_status()
         return AgentSessionCreateResponse(**response.json())

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        if request.stream:
+            return self._stream_agent_turn(request)
+        else:
+            return self._nonstream_agent_turn(request)
+
+    async def _stream_agent_turn(
+        self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
@@ -93,6 +101,9 @@ class AgentsClient(Agents):
                         print(data)
                     print(f"Error with parsing or validation: {e}")

+    async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
+        raise NotImplementedError("Non-streaming not implemented yet")
+

 async def _run_agent(
     api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@@ -132,8 +143,7 @@ async def _run_agent(
         log.print()


-async def run_llama_3_1(host: str, port: int):
-    model = "Llama3.1-8B-Instruct"
+async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [
@@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
     await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)


-async def run_llama_3_2_rag(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     urls = [
@@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
     )


-async def run_llama_3_2(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     # zero shot tools for llama3.2 text models
@@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
     )


-def main(host: str, port: int, run_type: str):
+def main(host: str, port: int, run_type: str, model: Optional[str] = None):
     assert run_type in [
         "tools_llama_3_1",
         "tools_llama_3_2",
@@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
         "tools_llama_3_2": run_llama_3_2,
         "rag_llama_3_2": run_llama_3_2_rag,
     }
-    asyncio.run(fn[run_type](host, port))
+    args = [host, port]
+    if model is not None:
+        args.append(model)
+    asyncio.run(fn[run_type](*args))


 if __name__ == "__main__":
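One consequence of the `*args` forwarding is that each run function's model default survives when no model is passed. A tiny sketch of just that forwarding logic (`build_args` is a hypothetical name; the real code inlines this in `main()`), assuming the script is exposed through `fire.Fire(main)`, which this diff does not show:

from typing import List, Optional


def build_args(host: str, port: int, model: Optional[str] = None) -> List:
    # Only append `model` when the caller supplied one, so each run_* function
    # keeps its own default (e.g. "Llama3.1-8B-Instruct").
    args = [host, port]
    if model is not None:
        args.append(model)
    return args


assert build_args("localhost", 5000) == ["localhost", 5000]
assert build_args("localhost", 5000, "Llama3.2-3B-Instruct") == [
    "localhost",
    5000,
    "Llama3.2-3B-Instruct",
]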

View file

@@ -22,7 +22,7 @@ class MemoryBankType(Enum):

 class CommonDef(BaseModel):
     identifier: str
-    provider_id: str
+    provider_id: Optional[str] = None


 @json_schema_type

View file

@@ -18,8 +18,8 @@ class ModelDef(BaseModel):
     llama_model: str = Field(
         description="Pointer to the core Llama family model",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this model"
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this model"
     )
     # For now, we are only supporting core llama models but as soon as finetuned
     # and other custom models (for example various quantizations) are allowed, there
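The effect of loosening `provider_id` here (and in the other `*Def` models in this commit) is that objects validate without naming a provider, leaving the routing layer to assign one later. A small pydantic sketch under that assumption; the `identifier` field is illustrative, since this hunk does not show it:

from typing import Optional

from pydantic import BaseModel, Field


class ModelDef(BaseModel):
    identifier: str  # illustrative; not shown in this hunk
    llama_model: str = Field(description="Pointer to the core Llama family model")
    provider_id: Optional[str] = Field(
        default=None, description="The provider instance which serves this model"
    )


# Valid without a provider now; with the old required field this raised
# a pydantic ValidationError.
m = ModelDef(identifier="my-model", llama_model="Llama3.1-8B-Instruct")
assert m.provider_id is None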

View file

@@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
         )
         print(response)

-    response = await client.run_shield(
-        shield_type="injection_shield",
-        messages=[message],
-    )
-    print(response)
-

 def main(host: str, port: int, image: str = None):
     asyncio.run(run_main(host, port, image))

View file

@@ -23,12 +23,12 @@ class ShieldDef(BaseModel):
     identifier: str = Field(
         description="A unique identifier for the shield type",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this shield"
-    )
     type: str = Field(
         description="The type of shield this is; the value is one of the ShieldType enum"
     )
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this shield"
+    )
     params: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional parameters needed for this shield",

View file

@@ -50,8 +50,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                     f"Provider `{provider.provider_type}` is not available for API `{api}`"
                 )

+            p = all_api_providers[api][provider.provider_type]
+            p.deps__ = [a.value for a in p.api_dependencies]
             spec = ProviderWithSpec(
-                spec=all_api_providers[api][provider.provider_type],
+                spec=p,
                 **(provider.dict()),
             )
             specs[provider.provider_id] = spec
@@ -93,6 +95,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                     registry=registry,
                     module="llama_stack.distribution.routers",
                     api_dependencies=inner_deps,
+                    deps__=(
+                        [x.value for x in inner_deps]
+                        + [f"inner-{info.router_api.value}"]
+                    ),
                 ),
             )
         }
@@ -107,6 +113,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                     module="llama_stack.distribution.routers",
                     routing_table_api=info.routing_table_api,
                     api_dependencies=[info.routing_table_api],
+                    deps__=([info.routing_table_api.value]),
                 ),
             )
         }
@@ -130,6 +137,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                 config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
                 module="llama_stack.distribution.inspect",
                 api_dependencies=apis,
+                deps__=([x.value for x in apis]),
             ),
         ),
     )
@@ -175,10 +183,8 @@ def topological_sort(
         deps = []
         for provider in providers:
-            for dep in provider.spec.api_dependencies:
-                deps.append(dep.value)
-            if isinstance(provider, AutoRoutedProviderSpec):
-                deps.append(f"inner-{provider.api}")
+            for dep in provider.spec.deps__:
+                deps.append(dep)

         for dep in deps:
             if dep not in visited:
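With dependencies precomputed as plain strings in `deps__`, the sort itself no longer special-cases `AutoRoutedProviderSpec`. A minimal sketch of a DFS topological sort over string dependency lists, roughly the shape of `topological_sort` here (names and the standalone signature are illustrative):

from typing import Dict, List


def topological_sort(deps_by_name: Dict[str, List[str]]) -> List[str]:
    # DFS post-order: every provider appears after everything it depends on.
    visited: set = set()
    order: List[str] = []

    def dfs(name: str) -> None:
        visited.add(name)
        for dep in deps_by_name.get(name, []):
            if dep not in visited:
                dfs(dep)
        order.append(name)

    for name in deps_by_name:
        if name not in visited:
            dfs(name)
    return order


# "agents" depends on "inference" and the auto-routed "inner-memory" provider.
print(topological_sort({
    "inference": [],
    "inner-memory": ["inference"],
    "agents": ["inference", "inner-memory"],
}))  # ['inference', 'inner-memory', 'agents']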

View file

@@ -39,6 +39,7 @@ class CommonRoutingTableImpl(RoutingTable):
     ) -> None:
         for obj in registry:
             if obj.provider_id not in impls_by_provider_id:
+                print(f"{impls_by_provider_id=}")
                 raise ValueError(
                     f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
                 )
@@ -70,7 +71,7 @@ class CommonRoutingTableImpl(RoutingTable):
     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
-            raise ValueError(f"Object `{routing_key}` not registered")
+            raise ValueError(f"`{routing_key}` not registered")

         obj = self.routing_key_to_object[routing_key]
         if obj.provider_id not in self.impls_by_provider_id:
@@ -86,7 +87,7 @@ class CommonRoutingTableImpl(RoutingTable):
     async def register_object(self, obj: RoutableObject):
         if obj.identifier in self.routing_key_to_object:
-            print(f"Object `{obj.identifier}` is already registered")
+            print(f"`{obj.identifier}` is already registered")
             return

         if not obj.provider_id:

View file

@@ -11,13 +11,9 @@ import json
 import signal
 import traceback

-from collections.abc import (
-    AsyncGenerator as AsyncGeneratorABC,
-    AsyncIterator as AsyncIteratorABC,
-)
 from contextlib import asynccontextmanager
 from ssl import SSLError
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
+from typing import Any, Dict, Optional

 import fire
 import httpx
@@ -44,42 +40,13 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
 from .endpoints import get_all_api_endpoints


-def is_async_iterator_type(typ):
-    if hasattr(typ, "__origin__"):
-        origin = typ.__origin__
-        if isinstance(origin, type):
-            return issubclass(
-                origin,
-                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
-            )
-        return False
-    return isinstance(
-        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
-    )
-
-
-def create_sse_event(data: Any, **kwargs) -> str:
+def create_sse_event(data: Any) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)

-    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
-    #
-    # we use the return type of the function to determine if there's an AsyncGenerator
-    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
-    # parameter called stream which _changes_ the return type. one correct way to fix this is:
-    #
-    # - have separate underlying functions for streaming and non-streaming because they need
-    #   to operate differently anyhow
-    # - do a late binding of the return type based on the parameters passed in
-    if kwargs.get("stream", False):
-        return f"data: {data}\n\n"
-    else:
-        print(
-            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
-        )
-        return data
+    return f"data: {data}\n\n"


 async def global_exception_handler(request: Request, exc: Exception):
@@ -221,27 +188,15 @@ def create_dynamic_passthrough(
     return endpoint


-def create_dynamic_typed_route(func: Any, method: str):
-    hints = get_type_hints(func)
-    response_model = hints.get("return")
-
-    # NOTE: I think it is better to just add a method within each Api
-    # "Protocol" / adapter-impl to tell what sort of a response this request
-    # is going to produce. /chat_completion can produce a streaming or
-    # non-streaming response depending on if request.stream is True / False.
-    is_streaming = is_async_iterator_type(response_model)
-
-    if is_streaming:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-            set_request_provider_data(request.headers)
-
-            async def sse_generator(event_gen):
-                try:
-                    async for item in event_gen:
-                        yield create_sse_event(item, **kwargs)
-                        await asyncio.sleep(0.01)
-                except asyncio.CancelledError:
-                    print("Generator cancelled")
+def is_streaming_request(func_name: str, request: Request, **kwargs):
+    # TODO: pass the api method and punt it to the Protocol definition directly
+    return kwargs.get("stream", False)
+
+
+async def sse_generator(event_gen):
+    try:
+        async for item in event_gen:
+            yield create_sse_event(item)
+            await asyncio.sleep(0.01)
+    except asyncio.CancelledError:
+        print("Generator cancelled")
@@ -258,18 +213,21 @@ def create_dynamic_typed_route(func: Any, method: str):
     finally:
         await end_trace()

-            return StreamingResponse(
-                sse_generator(func(**kwargs)), media_type="text/event-stream"
-            )
-    else:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-            set_request_provider_data(request.headers)

+def create_dynamic_typed_route(func: Any, method: str):
+    async def endpoint(request: Request, **kwargs):
+        await start_trace(func.__name__)
+        set_request_provider_data(request.headers)
+
+        is_streaming = is_streaming_request(func.__name__, request, **kwargs)

         try:
+            if is_streaming:
+                return StreamingResponse(
+                    sse_generator(func(**kwargs)), media_type="text/event-stream"
+                )
+            else:
                 return (
                     await func(**kwargs)
                     if asyncio.iscoroutinefunction(func)
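The net server-side shape: streaming is now a per-request decision read from the `stream` kwarg, and the same endpoint either wraps an async generator in an SSE `StreamingResponse` or awaits a regular result. A runnable FastAPI-flavored sketch of that shape (simplified; tracing, error handling, and the dynamic route construction are omitted):

import asyncio
import json
from typing import Any, AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()


def create_sse_event(data: Any) -> str:
    # Serialize one payload as a Server-Sent Events frame.
    return f"data: {json.dumps(data)}\n\n"


async def token_stream(message: str) -> AsyncGenerator[str, None]:
    for token in message.split():
        yield token
        await asyncio.sleep(0.01)


async def sse_generator(event_gen: AsyncGenerator) -> AsyncGenerator[str, None]:
    async for item in event_gen:
        yield create_sse_event(item)


@app.post("/chat_completion")
async def chat_completion(request: Request):
    body = await request.json()
    # Per-request decision, mirroring is_streaming_request(): look at the
    # `stream` field of the payload, not at the handler's return annotation.
    if body.get("stream", False):
        return StreamingResponse(
            sse_generator(token_stream(body["message"])),
            media_type="text/event-stream",
        )
    return {"completion": body["message"].upper()}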

View file

@@ -41,6 +41,9 @@ class ProviderSpec(BaseModel):
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )

+    # used internally by the resolver; this is a hack for now
+    deps__: List[str] = Field(default_factory=list)
+

 class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...

View file

@@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
+        assert request.stream is True, "Non-streaming not supported"
+
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
             raise ValueError(f"Session {request.session_id} not found")
@@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
             raise ValueError(f"Session {session_id} not found")

         if session_info.memory_bank_id is None:
-            memory_bank = await self.memory_api.create_memory_bank(
-                name=f"memory_bank_{session_id}",
-                config=VectorMemoryBankConfig(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            bank_id = f"memory_bank_{session_id}"
+            memory_bank = VectorMemoryBankDef(
+                identifier=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
             )
-            bank_id = memory_bank.bank_id
+            await self.memory_api.register_memory_bank(memory_bank)
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
             bank_id = session_info.memory_bank_id
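The memory-bank rework replaces a provider-allocated id (`create_memory_bank(...).bank_id`) with a caller-chosen identifier registered up front. A hedged sketch of the resulting get-or-create flow, assuming `VectorMemoryBankDef` is importable from the memory API and that `storage`/`memory_api` behave as in the diff:

async def ensure_session_memory_bank(storage, memory_api, session_id: str) -> str:
    # Reuse the bank if this session already has one.
    session_info = await storage.get_session_info(session_id)
    if session_info.memory_bank_id is not None:
        return session_info.memory_bank_id

    # The caller now derives the identifier from the session and registers
    # the bank itself, instead of asking the provider to mint an id.
    bank_id = f"memory_bank_{session_id}"
    await memory_api.register_memory_bank(
        VectorMemoryBankDef(  # imported from the memory API in the real code
            identifier=bank_id,
            embedding_model="all-MiniLM-L6-v2",
            chunk_size_in_tokens=512,
        )
    )
    await storage.add_memory_bank_to_session(session_id, bank_id)
    return bank_id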

View file

@@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        agent = await self.get_agent(agent_id)
-
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
             attachments=attachments,
-            stream=stream,
+            stream=True,
         )
+        if stream:
+            return self._create_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")

+    async def _create_agent_turn_streaming(
+        self,
+        request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
         async for event in agent.create_and_execute_turn(request):
             yield event
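With the implementation marked the same way as the protocol, callers invoke `create_agent_turn` without `await` and iterate the result; `stream=False` now fails fast with `NotImplementedError` instead of silently sending a non-SSE body. A hypothetical consumption sketch (`impl` and the ids are assumed):

async def consume_turn(impl, agent_id: str, session_id: str, messages) -> None:
    # No `await` on the call itself: it returns an async generator.
    turn = impl.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=messages,
        stream=True,
    )
    async for event in turn:
        print(event)  # streamed agent turn response events, in order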

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.