diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index f9ad44efc..6efe1b229 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -411,8 +411,10 @@ class Agents(Protocol):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...
 
+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or an `AgentTurnCreateResponse` depending on the value of `stream`.
     @webmethod(route="/agents/turn/create")
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py
index 27ebde57a..32bc9abdd 100644
--- a/llama_stack/apis/agents/client.py
+++ b/llama_stack/apis/agents/client.py
@@ -7,7 +7,7 @@
 import asyncio
 import json
 import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional
 
 import fire
 import httpx
@@ -67,9 +67,17 @@ class AgentsClient(Agents):
         response.raise_for_status()
         return AgentSessionCreateResponse(**response.json())
 
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        if request.stream:
+            return self._stream_agent_turn(request)
+        else:
+            return self._nonstream_agent_turn(request)
+
+    async def _stream_agent_turn(
+        self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
@@ -93,6 +101,9 @@ class AgentsClient(Agents):
                         print(data)
                         print(f"Error with parsing or validation: {e}")
 
+    async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
+        raise NotImplementedError("Non-streaming not implemented yet")
+
 
 async def _run_agent(
     api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@@ -132,8 +143,7 @@ async def _run_agent(
         log.print()
 
 
-async def run_llama_3_1(host: str, port: int):
-    model = "Llama3.1-8B-Instruct"
+async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")
 
     tool_definitions = [
@@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
     await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)
 
 
-async def run_llama_3_2_rag(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")
 
     urls = [
@@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
     )
 
 
-async def run_llama_3_2(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")
 
     # zero shot tools for llama3.2 text models
@@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
     )
 
 
-def main(host: str, port: int, run_type: str):
+def main(host: str, port: int, run_type: str, model: Optional[str] = None):
     assert run_type in [
         "tools_llama_3_1",
         "tools_llama_3_2",
@@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
         "tools_llama_3_2": run_llama_3_2,
         "rag_llama_3_2": run_llama_3_2_rag,
     }
-    asyncio.run(fn[run_type](host, port))
+    args = [host, port]
+    if model is not None:
+        args.append(model)
+    asyncio.run(fn[run_type](*args))
 
 
 if __name__ == "__main__":
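Note on the client change above: `create_agent_turn` is now a plain `def` that inspects `request.stream` and returns either an async generator or a coroutine, without awaiting anything itself; the caller then picks `async for` or `await`. A minimal self-contained sketch of that dispatch pattern (the `Request` shape and the string payloads are illustrative stand-ins, not the Llama Stack types):

```python
import asyncio
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Coroutine, Union


@dataclass
class Request:
    stream: bool


class Client:
    # A plain `def`: the return value is either an async generator (stream=True)
    # or a coroutine (stream=False); nothing is awaited here.
    def create_turn(self, request: Request) -> Union[AsyncGenerator, Coroutine[Any, Any, str]]:
        if request.stream:
            return self._stream_turn(request)
        return self._nonstream_turn(request)

    async def _stream_turn(self, request: Request) -> AsyncGenerator:
        for chunk in ("hello", "world"):
            yield chunk  # an `async def` containing `yield` is an async generator

    async def _nonstream_turn(self, request: Request) -> str:
        return "hello world"


async def main() -> None:
    client = Client()
    async for chunk in client.create_turn(Request(stream=True)):
        print(chunk)                                        # streaming path
    print(await client.create_turn(Request(stream=False)))  # non-streaming path


asyncio.run(main())
```

Keeping the dispatcher synchronous is what lets one API method offer both return shapes without the server inspecting type hints, which is exactly the hack the server.py change below removes.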
diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py
index d54c3868d..6d9f2f9f6 100644
--- a/llama_stack/apis/memory_banks/memory_banks.py
+++ b/llama_stack/apis/memory_banks/memory_banks.py
@@ -22,7 +22,7 @@ class MemoryBankType(Enum):
 
 class CommonDef(BaseModel):
     identifier: str
-    provider_id: str
+    provider_id: Optional[str] = None
 
 
 @json_schema_type
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 21dd17ca2..3a770af25 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -18,8 +18,8 @@ class ModelDef(BaseModel):
     llama_model: str = Field(
         description="Pointer to the core Llama family model",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this model"
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this model"
     )
     # For now, we are only supporting core llama models but as soon as finetuned
    # and other custom models (for example various quantizations) are allowed, there
diff --git a/llama_stack/apis/safety/client.py b/llama_stack/apis/safety/client.py
index e601e6dba..35843e206 100644
--- a/llama_stack/apis/safety/client.py
+++ b/llama_stack/apis/safety/client.py
@@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
         )
         print(response)
 
-    response = await client.run_shield(
-        shield_type="injection_shield",
-        messages=[message],
-    )
-    print(response)
-
 
 def main(host: str, port: int, image: str = None):
     asyncio.run(run_main(host, port, image))
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index db507a383..cec82516e 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -23,12 +23,12 @@ class ShieldDef(BaseModel):
     identifier: str = Field(
         description="A unique identifier for the shield type",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this shield"
-    )
     type: str = Field(
         description="The type of shield this is; the value is one of the ShieldType enum"
     )
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this shield"
+    )
     params: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional parameters needed for this shield",
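Across `CommonDef`, `ModelDef`, and `ShieldDef` above, `provider_id` becomes `Optional` with a `None` default, so a definition can be created without pinning a provider and the registration path can fill one in. A rough illustration of that fallback, using a simplified `ShieldDef` and a hypothetical `register` helper rather than the real routing-table code:

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class ShieldDef(BaseModel):
    identifier: str
    type: str
    # Optional: when omitted, the registration path picks a provider.
    provider_id: Optional[str] = Field(default=None)


def register(obj: ShieldDef, impls_by_provider_id: Dict[str, Any]) -> None:
    if not obj.provider_id:
        # Fall back to the only configured provider; ambiguous otherwise.
        if len(impls_by_provider_id) != 1:
            raise ValueError("provider_id is required when multiple providers are configured")
        obj.provider_id = next(iter(impls_by_provider_id))
    print(f"registered {obj.identifier} -> {obj.provider_id}")


register(
    ShieldDef(identifier="llama_guard", type="generic_content_shield"),
    {"meta-reference": object()},
)
```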
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 4db72d29e..857eef757 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -50,8 +50,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                 f"Provider `{provider.provider_type}` is not available for API `{api}`"
             )
 
+        p = all_api_providers[api][provider.provider_type]
+        p.deps__ = [a.value for a in p.api_dependencies]
         spec = ProviderWithSpec(
-            spec=all_api_providers[api][provider.provider_type],
+            spec=p,
             **(provider.dict()),
         )
         specs[provider.provider_id] = spec
@@ -93,6 +95,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                     registry=registry,
                     module="llama_stack.distribution.routers",
                     api_dependencies=inner_deps,
+                    deps__=(
+                        [x.value for x in inner_deps]
+                        + [f"inner-{info.router_api.value}"]
+                    ),
                 ),
             )
         }
@@ -107,6 +113,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                     module="llama_stack.distribution.routers",
                     routing_table_api=info.routing_table_api,
                     api_dependencies=[info.routing_table_api],
+                    deps__=([info.routing_table_api.value]),
                 ),
             )
         }
@@ -130,6 +137,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                 config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
                 module="llama_stack.distribution.inspect",
                 api_dependencies=apis,
+                deps__=([x.value for x in apis]),
             ),
         ),
     )
@@ -175,10 +183,8 @@ def topological_sort(
 
         deps = []
         for provider in providers:
-            for dep in provider.spec.api_dependencies:
-                deps.append(dep.value)
-            if isinstance(provider, AutoRoutedProviderSpec):
-                deps.append(f"inner-{provider.api}")
+            for dep in provider.spec.deps__:
+                deps.append(dep)
 
         for dep in deps:
             if dep not in visited:
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 73e26dd2e..7cb6e8432 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -39,6 +39,7 @@ class CommonRoutingTableImpl(RoutingTable):
     ) -> None:
         for obj in registry:
             if obj.provider_id not in impls_by_provider_id:
+                print(f"{impls_by_provider_id=}")
                 raise ValueError(
                     f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
                 )
@@ -70,7 +71,7 @@ class CommonRoutingTableImpl(RoutingTable):
 
     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
-            raise ValueError(f"Object `{routing_key}` not registered")
+            raise ValueError(f"`{routing_key}` not registered")
 
         obj = self.routing_key_to_object[routing_key]
         if obj.provider_id not in self.impls_by_provider_id:
@@ -86,7 +87,7 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def register_object(self, obj: RoutableObject):
         if obj.identifier in self.routing_key_to_object:
-            print(f"Object `{obj.identifier}` is already registered")
+            print(f"`{obj.identifier}` is already registered")
             return
 
         if not obj.provider_id:
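The resolver now flattens every dependency edge into the string list `deps__` (including synthetic edges such as `inner-inference` for auto-routed providers), so `topological_sort` walks plain strings instead of special-casing `AutoRoutedProviderSpec`. A standalone sketch of that DFS over a toy dependency map; the names are made up, and the real nodes are provider specs rather than bare strings:

```python
from typing import Dict, List


def topological_sort(deps_by_name: Dict[str, List[str]]) -> List[str]:
    visited: set = set()
    order: List[str] = []

    def dfs(name: str) -> None:
        visited.add(name)
        for dep in deps_by_name.get(name, []):
            if dep not in visited:
                dfs(dep)
        order.append(name)  # a node is appended only after all of its deps

    for name in deps_by_name:
        if name not in visited:
            dfs(name)
    return order


# Synthetic "inner-" edges keep an auto-routed provider after its inner impls.
print(
    topological_sort(
        {
            "agents": ["inference", "memory"],
            "inference": ["inner-inference"],
            "memory": ["inner-memory"],
            "inner-inference": [],
            "inner-memory": [],
        }
    )
)
# -> ['inner-inference', 'inference', 'inner-memory', 'memory', 'agents']
```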
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 7b19f7996..5c1a7806d 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -11,13 +11,9 @@ import json
 import signal
 import traceback
 
-from collections.abc import (
-    AsyncGenerator as AsyncGeneratorABC,
-    AsyncIterator as AsyncIteratorABC,
-)
 from contextlib import asynccontextmanager
 from ssl import SSLError
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
+from typing import Any, Dict, Optional
 
 import fire
 import httpx
@@ -44,42 +40,13 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
 from .endpoints import get_all_api_endpoints
 
 
-def is_async_iterator_type(typ):
-    if hasattr(typ, "__origin__"):
-        origin = typ.__origin__
-        if isinstance(origin, type):
-            return issubclass(
-                origin,
-                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
-            )
-        return False
-    return isinstance(
-        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
-    )
-
-
-def create_sse_event(data: Any, **kwargs) -> str:
+def create_sse_event(data: Any) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)
 
-    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
-    #
-    # we use the return type of the function to determine if there's an AsyncGenerator
-    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
-    # parameter called stream which _changes_ the return type. one correct way to fix this is:
-    #
-    # - have separate underlying functions for streaming and non-streaming because they need
-    #   to operate differently anyhow
-    # - do a late binding of the return type based on the parameters passed in
-    if kwargs.get("stream", False):
-        return f"data: {data}\n\n"
-    else:
-        print(
-            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
-        )
-        return data
+    return f"data: {data}\n\n"
 
 
 async def global_exception_handler(request: Request, exc: Exception):
@@ -221,65 +188,56 @@ def create_dynamic_passthrough(
     return endpoint
 
 
+def is_streaming_request(func_name: str, request: Request, **kwargs):
+    # TODO: pass the api method and punt it to the Protocol definition directly
+    return kwargs.get("stream", False)
+
+
+async def sse_generator(event_gen):
+    try:
+        async for item in event_gen:
+            yield create_sse_event(item)
+            await asyncio.sleep(0.01)
+    except asyncio.CancelledError:
+        print("Generator cancelled")
+        await event_gen.aclose()
+    except Exception as e:
+        traceback.print_exception(e)
+        yield create_sse_event(
+            {
+                "error": {
+                    "message": str(translate_exception(e)),
+                },
+            }
+        )
+    finally:
+        await end_trace()
+
+
 def create_dynamic_typed_route(func: Any, method: str):
-    hints = get_type_hints(func)
-    response_model = hints.get("return")
-
-    # NOTE: I think it is better to just add a method within each Api
-    # "Protocol" / adapter-impl to tell what sort of a response this request
-    # is going to produce. /chat_completion can produce a streaming or
-    # non-streaming response depending on if request.stream is True / False.
-    is_streaming = is_async_iterator_type(response_model)
+    async def endpoint(request: Request, **kwargs):
+        await start_trace(func.__name__)
 
-    if is_streaming:
+        set_request_provider_data(request.headers)
 
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
-            async def sse_generator(event_gen):
-                try:
-                    async for item in event_gen:
-                        yield create_sse_event(item, **kwargs)
-                        await asyncio.sleep(0.01)
-                except asyncio.CancelledError:
-                    print("Generator cancelled")
-                    await event_gen.aclose()
-                except Exception as e:
-                    traceback.print_exception(e)
-                    yield create_sse_event(
-                        {
-                            "error": {
-                                "message": str(translate_exception(e)),
-                            },
-                        }
-                    )
-                finally:
-                    await end_trace()
-
-            return StreamingResponse(
-                sse_generator(func(**kwargs)), media_type="text/event-stream"
-            )
-
-    else:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
-            try:
+        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
+        try:
+            if is_streaming:
+                return StreamingResponse(
+                    sse_generator(func(**kwargs)), media_type="text/event-stream"
+                )
+            else:
                 return (
                     await func(**kwargs)
                     if asyncio.iscoroutinefunction(func)
                     else func(**kwargs)
                 )
-            except Exception as e:
-                traceback.print_exception(e)
-                raise translate_exception(e) from e
-            finally:
-                await end_trace()
+        except Exception as e:
+            traceback.print_exception(e)
+            raise translate_exception(e) from e
+        finally:
+            await end_trace()
 
     sig = inspect.signature(func)
     new_params = [
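With the return-type hack gone, `create_sse_event` always emits the standard SSE framing (a `data:` line followed by a blank line), and a single endpoint body chooses between `StreamingResponse` and a plain awaited result based on the request's `stream` flag. A cut-down sketch of that endpoint shape as a standalone FastAPI app; the `/example` route and hard-coded chunks are assumptions for illustration, not part of the actual server:

```python
import asyncio
import json
from typing import Any

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def create_sse_event(data: Any) -> str:
    # SSE framing: a "data:" line followed by a blank line.
    return f"data: {json.dumps(data)}\n\n"


async def event_gen():
    for i in range(3):
        yield create_sse_event({"chunk": i})
        await asyncio.sleep(0.01)


@app.post("/example")
async def example(stream: bool = False):
    if stream:
        # One streaming body serves SSE; no inspection of type hints needed.
        return StreamingResponse(event_gen(), media_type="text/event-stream")
    return {"chunks": [0, 1, 2]}  # plain JSON when the client is not streaming
```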
diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py
index 0c8f6ad21..44ecb5355 100644
--- a/llama_stack/providers/datatypes.py
+++ b/llama_stack/providers/datatypes.py
@@ -41,6 +41,9 @@ class ProviderSpec(BaseModel):
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )
 
+    # used internally by the resolver; this is a hack for now
+    deps__: List[str] = Field(default_factory=list)
+
 
 class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...
diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index 661da10cc..fca335bf5 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
+        assert request.stream is True, "Non-streaming not supported"
+
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
             raise ValueError(f"Session {request.session_id} not found")
@@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
             raise ValueError(f"Session {session_id} not found")
 
         if session_info.memory_bank_id is None:
-            memory_bank = await self.memory_api.create_memory_bank(
-                name=f"memory_bank_{session_id}",
-                config=VectorMemoryBankConfig(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            bank_id = f"memory_bank_{session_id}"
+            memory_bank = VectorMemoryBankDef(
+                identifier=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
             )
-            bank_id = memory_bank.bank_id
+            await self.memory_api.register_memory_bank(memory_bank)
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
             bank_id = session_info.memory_bank_id
diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py
index 0673cd16f..e6fa1744d 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agents.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agents.py
@@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )
 
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        agent = await self.get_agent(agent_id)
-
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
            attachments=attachments,
-            stream=stream,
+            stream=True,
         )
+        if stream:
+            return self._create_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
 
+    async def _create_agent_turn_streaming(
+        self,
+        request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
         async for event in agent.create_and_execute_turn(request):
             yield event
diff --git a/llama_stack/providers/tests/agents/__init__.py b/llama_stack/providers/tests/agents/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/tests/agents/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
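On the wire, the agent turn path is now streaming-only end to end: the meta-reference impl raises `NotImplementedError` for `stream=False`, and the client's `_stream_agent_turn` opens a streaming POST with httpx and decodes each SSE `data:` line. A minimal sketch of that consumption loop, with a placeholder URL and payload and the real client's error handling omitted:

```python
import json

import httpx


async def stream_events(url: str, payload: dict):
    # Open a streaming POST and yield each decoded SSE "data:" payload.
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload, timeout=20) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    yield json.loads(line[len("data: "):])


# Hypothetical usage against a local server:
#   async for event in stream_events(
#       "http://localhost:5000/agents/turn/create", {"stream": True}
#   ):
#       print(event)
```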