Updates to server.py to clean up streaming vs. non-streaming handling

Also make sure agent turn create is correctly marked.

parent 640c5c54f7
commit 7f1160296c

13 changed files with 115 additions and 128 deletions
@@ -411,8 +411,10 @@ class Agents(Protocol):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or an `AgentTurnCreateResponse` depending on the value of `stream`.
     @webmethod(route="/agents/turn/create")
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
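Why the protocol method is a plain `def`: a non-async method can hand back either an
async generator (when stream=True) or a coroutine to await (when stream=False), which a
single `async def` signature cannot express. A minimal, self-contained sketch of the
pattern (hypothetical names, not from this commit):

    from typing import Any, AsyncGenerator, Coroutine, Union

    def create_turn(stream: bool) -> Union[AsyncGenerator, Coroutine[Any, Any, str]]:
        async def _streaming() -> AsyncGenerator:
            for chunk in ("step_start", "turn_complete"):
                yield chunk  # async-generator function: calling it returns the generator

        async def _nonstreaming() -> str:
            return "turn_complete"  # coroutine function: calling it returns an awaitable

        return _streaming() if stream else _nonstreaming()

Callers pick the matching consumption style: `async for event in create_turn(True)`
versus `await create_turn(False)`.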
@@ -7,7 +7,7 @@
 import asyncio
 import json
 import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional

 import fire
 import httpx
@@ -67,9 +67,17 @@ class AgentsClient(Agents):
         response.raise_for_status()
         return AgentSessionCreateResponse(**response.json())

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
     ) -> AsyncGenerator:
+        if request.stream:
+            return self._stream_agent_turn(request)
+        else:
+            return self._nonstream_agent_turn(request)
+
+    async def _stream_agent_turn(
+        self, request: AgentTurnCreateRequest
+    ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
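With this split, the public client method stays synchronous and only routes to the right
coroutine or generator. A hedged usage sketch (assumes the `stream` field shown on
AgentTurnCreateRequest in this diff):

    async def consume_turn(client: AgentsClient, request: AgentTurnCreateRequest) -> None:
        if request.stream:
            # the returned async generator is iterated, never awaited
            async for chunk in client.create_agent_turn(request):
                print(chunk)
        else:
            # the non-streaming branch returns a coroutine to await
            # (it currently raises NotImplementedError; see the next hunk)
            response = await client.create_agent_turn(request)
            print(response)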
@@ -93,6 +101,9 @@ class AgentsClient(Agents):
                             print(data)
                         print(f"Error with parsing or validation: {e}")

+    async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
+        raise NotImplementedError("Non-streaming not implemented yet")
+

 async def _run_agent(
     api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@@ -132,8 +143,7 @@ async def _run_agent(
     log.print()


-async def run_llama_3_1(host: str, port: int):
-    model = "Llama3.1-8B-Instruct"
+async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [

@@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
     await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)


-async def run_llama_3_2_rag(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     urls = [

@@ -215,8 +224,7 @@ async def run_llama_3_2_rag(host: str, port: int):
     )


-async def run_llama_3_2(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     # zero shot tools for llama3.2 text models

@@ -262,7 +270,7 @@ async def run_llama_3_2(host: str, port: int):
     )


-def main(host: str, port: int, run_type: str):
+def main(host: str, port: int, run_type: str, model: Optional[str] = None):
     assert run_type in [
         "tools_llama_3_1",
         "tools_llama_3_2",

@@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
         "tools_llama_3_2": run_llama_3_2,
         "rag_llama_3_2": run_llama_3_2_rag,
     }
-    asyncio.run(fn[run_type](host, port))
+    args = [host, port]
+    if model is not None:
+        args.append(model)
+    asyncio.run(fn[run_type](*args))


 if __name__ == "__main__":
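Assuming the script's entry point hands `main` to `fire.Fire` (fire is imported above),
the new optional `model` argument stays backward compatible: when omitted, each `run_*`
function keeps its own default. Illustrative invocations (script name hypothetical):

    python client.py localhost 5000 tools_llama_3_1
    python client.py localhost 5000 tools_llama_3_1 --model Llama3.1-70B-Instruct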
@@ -22,7 +22,7 @@ class MemoryBankType(Enum):

 class CommonDef(BaseModel):
     identifier: str
-    provider_id: str
+    provider_id: Optional[str] = None


 @json_schema_type
@@ -18,8 +18,8 @@ class ModelDef(BaseModel):
     llama_model: str = Field(
         description="Pointer to the core Llama family model",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this model"
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this model"
     )
     # For now, we are only supporting core llama models but as soon as finetuned
     # and other custom models (for example various quantizations) are allowed, there
@@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
     )
     print(response)

-    response = await client.run_shield(
-        shield_type="injection_shield",
-        messages=[message],
-    )
-    print(response)
-

 def main(host: str, port: int, image: str = None):
     asyncio.run(run_main(host, port, image))
@@ -23,12 +23,12 @@ class ShieldDef(BaseModel):
     identifier: str = Field(
         description="A unique identifier for the shield type",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this shield"
-    )
     type: str = Field(
         description="The type of shield this is; the value is one of the ShieldType enum"
     )
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this shield"
+    )
     params: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional parameters needed for this shield",
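Making `provider_id` optional here (and on CommonDef and ModelDef above) lets registry
entries omit the field when the choice of provider is unambiguous. A minimal sketch of
the fallback such a schema change enables (hypothetical resolver logic, not from this
commit):

    from typing import Any, Dict, Optional

    def resolve_provider_id(provider_id: Optional[str], impls: Dict[str, Any]) -> str:
        if provider_id is not None:
            return provider_id
        if len(impls) == 1:
            return next(iter(impls))  # only one provider configured: use it
        raise ValueError("provider_id is required when multiple providers are configured")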
@@ -50,8 +50,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                 f"Provider `{provider.provider_type}` is not available for API `{api}`"
             )

+        p = all_api_providers[api][provider.provider_type]
+        p.deps__ = [a.value for a in p.api_dependencies]
         spec = ProviderWithSpec(
-            spec=all_api_providers[api][provider.provider_type],
+            spec=p,
             **(provider.dict()),
         )
         specs[provider.provider_id] = spec
@@ -93,6 +95,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                 registry=registry,
                 module="llama_stack.distribution.routers",
                 api_dependencies=inner_deps,
+                deps__=(
+                    [x.value for x in inner_deps]
+                    + [f"inner-{info.router_api.value}"]
+                ),
             ),
         )
     }
@@ -107,6 +113,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
                 module="llama_stack.distribution.routers",
                 routing_table_api=info.routing_table_api,
                 api_dependencies=[info.routing_table_api],
+                deps__=([info.routing_table_api.value]),
             ),
         )
     }
@@ -130,6 +137,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
             config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
             module="llama_stack.distribution.inspect",
             api_dependencies=apis,
+            deps__=([x.value for x in apis]),
         ),
     ),
 )
@@ -175,10 +183,8 @@ def topological_sort(

     deps = []
     for provider in providers:
-        for dep in provider.spec.api_dependencies:
-            deps.append(dep.value)
-        if isinstance(provider, AutoRoutedProviderSpec):
-            deps.append(f"inner-{provider.api}")
+        for dep in provider.spec.deps__:
+            deps.append(dep)

     for dep in deps:
         if dep not in visited:
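Because every spec now carries a flat, pre-computed `deps__` list of strings, the sort no
longer special-cases AutoRoutedProviderSpec. A self-contained sketch of the resulting
dependency walk (illustrative names; dicts stand in for provider specs):

    def topo_order(providers: dict) -> list:
        # providers maps name -> list of dependency names (like spec.deps__)
        visited, order = set(), []

        def dfs(name: str) -> None:
            if name in visited:
                return
            visited.add(name)
            for dep in providers.get(name, []):
                dfs(dep)
            order.append(name)  # dependencies land before dependents

        for name in providers:
            dfs(name)
        return order

    assert topo_order({"agents": ["inference"], "inference": []}) == ["inference", "agents"]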
@@ -39,6 +39,7 @@ class CommonRoutingTableImpl(RoutingTable):
     ) -> None:
         for obj in registry:
             if obj.provider_id not in impls_by_provider_id:
+                print(f"{impls_by_provider_id=}")
                 raise ValueError(
                     f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
                 )
@@ -70,7 +71,7 @@ class CommonRoutingTableImpl(RoutingTable):

     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
-            raise ValueError(f"Object `{routing_key}` not registered")
+            raise ValueError(f"`{routing_key}` not registered")

         obj = self.routing_key_to_object[routing_key]
         if obj.provider_id not in self.impls_by_provider_id:
@@ -86,7 +87,7 @@ class CommonRoutingTableImpl(RoutingTable):

     async def register_object(self, obj: RoutableObject):
         if obj.identifier in self.routing_key_to_object:
-            print(f"Object `{obj.identifier}` is already registered")
+            print(f"`{obj.identifier}` is already registered")
             return

         if not obj.provider_id:
@@ -11,13 +11,9 @@ import json
 import signal
 import traceback

-from collections.abc import (
-    AsyncGenerator as AsyncGeneratorABC,
-    AsyncIterator as AsyncIteratorABC,
-)
 from contextlib import asynccontextmanager
 from ssl import SSLError
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
+from typing import Any, Dict, Optional

 import fire
 import httpx
@@ -44,42 +40,13 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
 from .endpoints import get_all_api_endpoints


-def is_async_iterator_type(typ):
-    if hasattr(typ, "__origin__"):
-        origin = typ.__origin__
-        if isinstance(origin, type):
-            return issubclass(
-                origin,
-                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
-            )
-        return False
-    return isinstance(
-        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
-    )
-
-
-def create_sse_event(data: Any, **kwargs) -> str:
+def create_sse_event(data: Any) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)

-    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
-    #
-    # we use the return type of the function to determine if there's an AsyncGenerator
-    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
-    # parameter called stream which _changes_ the return type. one correct way to fix this is:
-    #
-    # - have separate underlying functions for streaming and non-streaming because they need
-    #   to operate differently anyhow
-    # - do a late binding of the return type based on the parameters passed in
-    if kwargs.get("stream", False):
-        return f"data: {data}\n\n"
-    else:
-        print(
-            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
-        )
-        return data
+    return f"data: {data}\n\n"


 async def global_exception_handler(request: Request, exc: Exception):
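After this change, `create_sse_event` unconditionally emits standard Server-Sent Events
framing: one `data:` field terminated by a blank line, with the `stream` kwarg hack gone.
A quick sanity check of the wire format (the BaseModel branch is omitted for brevity):

    import json

    def create_sse_event(data) -> str:
        return f"data: {json.dumps(data)}\n\n"

    assert create_sse_event({"delta": "hi"}) == 'data: {"delta": "hi"}\n\n'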
@@ -221,65 +188,56 @@ def create_dynamic_passthrough(
     return endpoint


+def is_streaming_request(func_name: str, request: Request, **kwargs):
+    # TODO: pass the api method and punt it to the Protocol definition directly
+    return kwargs.get("stream", False)
+
+
+async def sse_generator(event_gen):
+    try:
+        async for item in event_gen:
+            yield create_sse_event(item)
+            await asyncio.sleep(0.01)
+    except asyncio.CancelledError:
+        print("Generator cancelled")
+        await event_gen.aclose()
+    except Exception as e:
+        traceback.print_exception(e)
+        yield create_sse_event(
+            {
+                "error": {
+                    "message": str(translate_exception(e)),
+                },
+            }
+        )
+    finally:
+        await end_trace()
+
+
 def create_dynamic_typed_route(func: Any, method: str):
-    hints = get_type_hints(func)
-    response_model = hints.get("return")
-
-    # NOTE: I think it is better to just add a method within each Api
-    # "Protocol" / adapter-impl to tell what sort of a response this request
-    # is going to produce. /chat_completion can produce a streaming or
-    # non-streaming response depending on if request.stream is True / False.
-    is_streaming = is_async_iterator_type(response_model)
-
-    if is_streaming:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
-            async def sse_generator(event_gen):
-                try:
-                    async for item in event_gen:
-                        yield create_sse_event(item, **kwargs)
-                        await asyncio.sleep(0.01)
-                except asyncio.CancelledError:
-                    print("Generator cancelled")
-                    await event_gen.aclose()
-                except Exception as e:
-                    traceback.print_exception(e)
-                    yield create_sse_event(
-                        {
-                            "error": {
-                                "message": str(translate_exception(e)),
-                            },
-                        }
-                    )
-                finally:
-                    await end_trace()
-
-            return StreamingResponse(
-                sse_generator(func(**kwargs)), media_type="text/event-stream"
-            )
-
-    else:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
-            try:
-                return (
-                    await func(**kwargs)
-                    if asyncio.iscoroutinefunction(func)
-                    else func(**kwargs)
-                )
-            except Exception as e:
-                traceback.print_exception(e)
-                raise translate_exception(e) from e
-            finally:
-                await end_trace()
+    async def endpoint(request: Request, **kwargs):
+        await start_trace(func.__name__)
+
+        set_request_provider_data(request.headers)
+
+        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
+        try:
+            if is_streaming:
+                return StreamingResponse(
+                    sse_generator(func(**kwargs)), media_type="text/event-stream"
+                )
+            else:
+                return (
+                    await func(**kwargs)
+                    if asyncio.iscoroutinefunction(func)
+                    else func(**kwargs)
+                )
+        except Exception as e:
+            traceback.print_exception(e)
+            raise translate_exception(e) from e
+        finally:
+            await end_trace()

     sig = inspect.signature(func)
     new_params = [
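The streaming decision now happens per request body rather than per return-type
annotation: `is_streaming_request` just reads the `stream` kwarg. A tiny illustration of
the two paths the single `endpoint` closure takes (hypothetical kwargs):

    kwargs = {"agent_id": "abc", "stream": True}
    assert kwargs.get("stream", False)      # -> StreamingResponse over sse_generator

    kwargs = {"agent_id": "abc"}
    assert not kwargs.get("stream", False)  # -> plain (possibly awaited) response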
@@ -41,6 +41,9 @@ class ProviderSpec(BaseModel):
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )

+    # used internally by the resolver; this is a hack for now
+    deps__: List[str] = Field(default_factory=list)
+

 class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...
@@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
+        assert request.stream is True, "Non-streaming not supported"
+
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
             raise ValueError(f"Session {request.session_id} not found")
@@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
             raise ValueError(f"Session {session_id} not found")

         if session_info.memory_bank_id is None:
-            memory_bank = await self.memory_api.create_memory_bank(
-                name=f"memory_bank_{session_id}",
-                config=VectorMemoryBankConfig(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            bank_id = f"memory_bank_{session_id}"
+            memory_bank = VectorMemoryBankDef(
+                identifier=bank_id,
+                embedding_model="all-MiniLM-L6-v2",
+                chunk_size_in_tokens=512,
             )
-            bank_id = memory_bank.bank_id
+            await self.memory_api.register_memory_bank(memory_bank)
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
             bank_id = session_info.memory_bank_id
@@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        agent = await self.get_agent(agent_id)
-
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
             attachments=attachments,
-            stream=stream,
+            stream=True,
         )
+        if stream:
+            return self._create_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _create_agent_turn_streaming(
+        self,
+        request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
         async for event in agent.create_and_execute_turn(request):
             yield event
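The implementation now mirrors the client: a plain `def create_agent_turn` builds the
internal request and returns the async generator from `_create_agent_turn_streaming`
without awaiting it. A standalone analogue of the shape (hypothetical class, not from
this commit):

    class TurnImpl:
        def create_turn(self, stream: bool = False):
            if stream:
                return self._streaming()  # hand back the generator; no await needed
            raise NotImplementedError("Non-streaming agent turns not yet implemented")

        async def _streaming(self):
            for event in ("step_start", "step_complete", "turn_complete"):
                yield event

Usage: `async for event in TurnImpl().create_turn(stream=True): ...`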
llama_stack/providers/tests/agents/__init__.py (new file, +5 lines)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.