mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

commit 7f1160296c (parent 640c5c54f7)

    Updates to server.py to clean up streaming vs non-streaming stuff

    Also make sure agent turn create is correctly marked

13 changed files with 115 additions and 128 deletions
@@ -411,8 +411,10 @@ class Agents(Protocol):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse: ...

+    # This method is not `async def` because it can result in either an
+    # `AsyncGenerator` or a `AgentTurnCreateResponse` depending on the value of `stream`.
     @webmethod(route="/agents/turn/create")
-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
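The comment added in this hunk is the substance of the change: a plain `def` can hand back either an async generator (when `stream=True`) or an awaitable (when `stream=False`), whereas `async def` would force every caller to `await` first. A minimal, self-contained sketch of the dual-return pattern (all names here are illustrative, not from the codebase):

```python
import asyncio
from typing import AsyncGenerator, Awaitable, Union


async def _stream() -> AsyncGenerator[int, None]:
    # Streaming path: an async generator the caller drives with `async for`.
    for i in range(3):
        yield i


async def _single() -> int:
    # Non-streaming path: a coroutine the caller awaits.
    return 42


def create_turn(stream: bool) -> Union[AsyncGenerator[int, None], Awaitable[int]]:
    # Deliberately not `async def`: calling this merely selects which kind
    # of object to return; the caller chooses `async for` vs `await`.
    return _stream() if stream else _single()


async def main() -> None:
    async for item in create_turn(stream=True):
        print(item)  # 0, 1, 2
    print(await create_turn(stream=False))  # 42


asyncio.run(main())
```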
@@ -7,7 +7,7 @@
 import asyncio
 import json
 import os
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Optional

 import fire
 import httpx
@@ -67,9 +67,17 @@ class AgentsClient(Agents):
         response.raise_for_status()
         return AgentSessionCreateResponse(**response.json())

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        if request.stream:
+            return self._stream_agent_turn(request)
+        else:
+            return self._nonstream_agent_turn(request)
+
+    async def _stream_agent_turn(
+        self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
         async with httpx.AsyncClient() as client:
             async with client.stream(
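The client now mirrors the protocol change: a plain `def create_agent_turn` returns either the streaming generator or the non-streaming coroutine based on `request.stream`. On the wire, `_stream_agent_turn` consumes `data: <json>` frames; a sketch of that consuming pattern, assuming the server frames events the way `create_sse_event` (later in this diff) does — `stream_sse` and its parameters are illustrative:

```python
import json

import httpx


async def stream_sse(url: str, payload: dict):
    # Open a streaming POST and decode "data: <json>\n\n" frames as they arrive.
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, json=payload, timeout=None) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    yield json.loads(line[len("data: "):])
```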
@@ -93,6 +101,9 @@ class AgentsClient(Agents):
                         print(data)
                         print(f"Error with parsing or validation: {e}")

+    async def _nonstream_agent_turn(self, request: AgentTurnCreateRequest):
+        raise NotImplementedError("Non-streaming not implemented yet")
+

 async def _run_agent(
     api, model, tool_definitions, tool_prompt_format, user_prompts, attachments=None
@@ -132,8 +143,7 @@ async def _run_agent
     log.print()


-async def run_llama_3_1(host: str, port: int):
-    model = "Llama3.1-8B-Instruct"
+async def run_llama_3_1(host: str, port: int, model: str = "Llama3.1-8B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     tool_definitions = [
@@ -173,8 +183,7 @@ async def run_llama_3_1(host: str, port: int):
     await _run_agent(api, model, tool_definitions, ToolPromptFormat.json, user_prompts)


-async def run_llama_3_2_rag(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2_rag(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     urls = [
@@ -215,8 +224,7 @@
     )


-async def run_llama_3_2(host: str, port: int):
-    model = "Llama3.2-3B-Instruct"
+async def run_llama_3_2(host: str, port: int, model: str = "Llama3.2-3B-Instruct"):
     api = AgentsClient(f"http://{host}:{port}")

     # zero shot tools for llama3.2 text models
@@ -262,7 +270,7 @@
     )


-def main(host: str, port: int, run_type: str):
+def main(host: str, port: int, run_type: str, model: Optional[str] = None):
     assert run_type in [
         "tools_llama_3_1",
         "tools_llama_3_2",
@@ -274,7 +282,10 @@ def main(host: str, port: int, run_type: str):
         "tools_llama_3_2": run_llama_3_2,
         "rag_llama_3_2": run_llama_3_2_rag,
     }
-    asyncio.run(fn[run_type](host, port))
+    args = [host, port]
+    if model is not None:
+        args.append(model)
+    asyncio.run(fn[run_type](*args))


 if __name__ == "__main__":
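The `*args` construction keeps `model` optional at the CLI while letting each runner keep its own default. The same pattern in miniature:

```python
import asyncio
from typing import Optional


async def runner(host: str, port: int, model: str = "default-model") -> None:
    print(host, port, model)


def main(host: str, port: int, model: Optional[str] = None) -> None:
    # Forward the optional argument only when it was actually supplied,
    # so the runner's own default applies otherwise.
    args = [host, port]
    if model is not None:
        args.append(model)
    asyncio.run(runner(*args))


main("localhost", 5000)              # -> localhost 5000 default-model
main("localhost", 5000, "my-model")  # -> localhost 5000 my-model
```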
@@ -22,7 +22,7 @@ class MemoryBankType(Enum):

 class CommonDef(BaseModel):
     identifier: str
-    provider_id: str
+    provider_id: Optional[str] = None


 @json_schema_type
@@ -18,8 +18,8 @@ class ModelDef(BaseModel):
     llama_model: str = Field(
         description="Pointer to the core Llama family model",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this model"
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this model"
     )
     # For now, we are only supporting core llama models but as soon as finetuned
     # and other custom models (for example various quantizations) are allowed, there
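This `provider_id` loosening (and the matching ones in the memory-bank and shield hunks) follows the standard pydantic recipe: `Optional[...]` in the annotation plus `default=None` in `Field`. A minimal sketch (`ExampleDef` is illustrative):

```python
from typing import Optional

from pydantic import BaseModel, Field


class ExampleDef(BaseModel):
    identifier: str
    # Validates as str when present; defaults to None when omitted.
    provider_id: Optional[str] = Field(
        default=None, description="The provider instance which serves this object"
    )


print(ExampleDef(identifier="x").provider_id)                   # None
print(ExampleDef(identifier="x", provider_id="p").provider_id)  # p
```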
@@ -96,12 +96,6 @@ async def run_main(host: str, port: int, image_path: str = None):
     )
     print(response)

-    response = await client.run_shield(
-        shield_type="injection_shield",
-        messages=[message],
-    )
-    print(response)
-

 def main(host: str, port: int, image: str = None):
     asyncio.run(run_main(host, port, image))
@@ -23,12 +23,12 @@ class ShieldDef(BaseModel):
     identifier: str = Field(
         description="A unique identifier for the shield type",
     )
-    provider_id: str = Field(
-        description="The provider instance which serves this shield"
-    )
     type: str = Field(
         description="The type of shield this is; the value is one of the ShieldType enum"
     )
+    provider_id: Optional[str] = Field(
+        default=None, description="The provider instance which serves this shield"
+    )
     params: Dict[str, Any] = Field(
         default_factory=dict,
         description="Any additional parameters needed for this shield",
@@ -50,8 +50,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                 f"Provider `{provider.provider_type}` is not available for API `{api}`"
             )

+        p = all_api_providers[api][provider.provider_type]
+        p.deps__ = [a.value for a in p.api_dependencies]
         spec = ProviderWithSpec(
-            spec=all_api_providers[api][provider.provider_type],
+            spec=p,
             **(provider.dict()),
         )
         specs[provider.provider_id] = spec
@@ -93,6 +95,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                     registry=registry,
                     module="llama_stack.distribution.routers",
                     api_dependencies=inner_deps,
+                    deps__=(
+                        [x.value for x in inner_deps]
+                        + [f"inner-{info.router_api.value}"]
+                    ),
                 ),
             )
         }
@@ -107,6 +113,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                     module="llama_stack.distribution.routers",
                     routing_table_api=info.routing_table_api,
                     api_dependencies=[info.routing_table_api],
+                    deps__=([info.routing_table_api.value]),
                 ),
             )
         }
@@ -130,6 +137,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
                 config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
                 module="llama_stack.distribution.inspect",
                 api_dependencies=apis,
+                deps__=([x.value for x in apis]),
             ),
         ),
     )
@@ -175,10 +183,8 @@ def topological_sort(

     deps = []
     for provider in providers:
-        for dep in provider.spec.api_dependencies:
-            deps.append(dep.value)
-        if isinstance(provider, AutoRoutedProviderSpec):
-            deps.append(f"inner-{provider.api}")
+        for dep in provider.spec.deps__:
+            deps.append(dep)

     for dep in deps:
         if dep not in visited:
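With dependencies normalized to plain strings in `deps__` (populated by the resolver hunks above), the sort no longer needs to special-case `AutoRoutedProviderSpec`. For reference, a self-contained string-keyed depth-first topological sort in the same spirit (illustrative, not the repo's exact implementation):

```python
from typing import Dict, List


def topo_sort(deps_by_name: Dict[str, List[str]]) -> List[str]:
    # Depth-first post-order: a node is emitted only after everything
    # it depends on has been emitted.
    visited: set = set()
    order: List[str] = []

    def dfs(name: str) -> None:
        visited.add(name)
        for dep in deps_by_name.get(name, []):
            if dep not in visited:
                dfs(dep)
        order.append(name)

    for name in deps_by_name:
        if name not in visited:
            dfs(name)
    return order


print(topo_sort({"agents": ["inference", "safety"], "safety": ["inference"], "inference": []}))
# -> ['inference', 'safety', 'agents']
```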
@@ -39,6 +39,7 @@ class CommonRoutingTableImpl(RoutingTable):
     ) -> None:
         for obj in registry:
             if obj.provider_id not in impls_by_provider_id:
+                print(f"{impls_by_provider_id=}")
                 raise ValueError(
                     f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
                 )
@@ -70,7 +71,7 @@ class CommonRoutingTableImpl(RoutingTable):

     def get_provider_impl(self, routing_key: str) -> Any:
         if routing_key not in self.routing_key_to_object:
-            raise ValueError(f"Object `{routing_key}` not registered")
+            raise ValueError(f"`{routing_key}` not registered")

         obj = self.routing_key_to_object[routing_key]
         if obj.provider_id not in self.impls_by_provider_id:
@@ -86,7 +87,7 @@ class CommonRoutingTableImpl(RoutingTable):

     async def register_object(self, obj: RoutableObject):
         if obj.identifier in self.routing_key_to_object:
-            print(f"Object `{obj.identifier}` is already registered")
+            print(f"`{obj.identifier}` is already registered")
             return

         if not obj.provider_id:
@@ -11,13 +11,9 @@ import json
 import signal
 import traceback

-from collections.abc import (
-    AsyncGenerator as AsyncGeneratorABC,
-    AsyncIterator as AsyncIteratorABC,
-)
 from contextlib import asynccontextmanager
 from ssl import SSLError
-from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
+from typing import Any, Dict, Optional

 import fire
 import httpx
@@ -44,42 +40,13 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
 from .endpoints import get_all_api_endpoints


-def is_async_iterator_type(typ):
-    if hasattr(typ, "__origin__"):
-        origin = typ.__origin__
-        if isinstance(origin, type):
-            return issubclass(
-                origin,
-                (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
-            )
-        return False
-    return isinstance(
-        typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
-    )
-
-
-def create_sse_event(data: Any, **kwargs) -> str:
+def create_sse_event(data: Any) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)

-    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
-    #
-    # we use the return type of the function to determine if there's an AsyncGenerator
-    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
-    # parameter called stream which _changes_ the return type. one correct way to fix this is:
-    #
-    # - have separate underlying functions for streaming and non-streaming because they need
-    #   to operate differently anyhow
-    # - do a late binding of the return type based on the parameters passed in
-    if kwargs.get("stream", False):
-        return f"data: {data}\n\n"
-    else:
-        print(
-            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
-        )
-        return data
+    return f"data: {data}\n\n"


 async def global_exception_handler(request: Request, exc: Exception):
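`create_sse_event` now always emits Server-Sent Events framing: a `data:` field terminated by a blank line. Re-stated as a runnable sketch (pydantic v1-style `.json()`, matching the diff):

```python
import json

from pydantic import BaseModel


class Chunk(BaseModel):
    delta: str


def create_sse_event(data) -> str:
    # SSE framing: "data: <payload>\n\n" -- the blank line ends the event.
    if isinstance(data, BaseModel):
        data = data.json()
    else:
        data = json.dumps(data)
    return f"data: {data}\n\n"


print(repr(create_sse_event(Chunk(delta="hi"))))  # 'data: {"delta": "hi"}\n\n'
```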
@@ -221,27 +188,15 @@ def create_dynamic_passthrough(
     return endpoint


-def create_dynamic_typed_route(func: Any, method: str):
-    hints = get_type_hints(func)
-    response_model = hints.get("return")
-
-    # NOTE: I think it is better to just add a method within each Api
-    # "Protocol" / adapter-impl to tell what sort of a response this request
-    # is going to produce. /chat_completion can produce a streaming or
-    # non-streaming response depending on if request.stream is True / False.
-    is_streaming = is_async_iterator_type(response_model)
-
-    if is_streaming:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
-            async def sse_generator(event_gen):
-                try:
-                    async for item in event_gen:
-                        yield create_sse_event(item, **kwargs)
-                        await asyncio.sleep(0.01)
-                except asyncio.CancelledError:
-                    print("Generator cancelled")
+def is_streaming_request(func_name: str, request: Request, **kwargs):
+    # TODO: pass the api method and punt it to the Protocol definition directly
+    return kwargs.get("stream", False)
+
+
+async def sse_generator(event_gen):
+    try:
+        async for item in event_gen:
+            yield create_sse_event(item)
+            await asyncio.sleep(0.01)
+    except asyncio.CancelledError:
+        print("Generator cancelled")
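This hunk swaps return-type inspection for a check on the request itself — the "late binding of the return type based on the parameters passed in" option named in the comment block deleted above. A small sketch of why signature sniffing cannot work here (`chat_completion` stands in for any dual-mode method):

```python
from typing import AsyncGenerator, Union, get_type_hints


async def chat_completion(stream: bool = False) -> Union[AsyncGenerator, dict]:
    ...


# The annotation is one fixed Union, so the signature alone cannot say
# whether a particular call will stream (exact repr varies by version):
print(get_type_hints(chat_completion)["return"])


def is_streaming_request(**kwargs) -> bool:
    # Late binding: only the runtime `stream` argument is decisive.
    return bool(kwargs.get("stream", False))


print(is_streaming_request(stream=True))   # True
print(is_streaming_request())              # False
```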
@@ -258,18 +213,21 @@ def create_dynamic_typed_route(func: Any, method: str):
     finally:
         await end_trace()

-            return StreamingResponse(
-                sse_generator(func(**kwargs)), media_type="text/event-stream"
-            )
-
-    else:
-
-        async def endpoint(request: Request, **kwargs):
-            await start_trace(func.__name__)
-
-            set_request_provider_data(request.headers)
-
+
+def create_dynamic_typed_route(func: Any, method: str):
+    async def endpoint(request: Request, **kwargs):
+        await start_trace(func.__name__)
+
+        set_request_provider_data(request.headers)
+
+        is_streaming = is_streaming_request(func.__name__, request, **kwargs)
         try:
+            if is_streaming:
+                return StreamingResponse(
+                    sse_generator(func(**kwargs)), media_type="text/event-stream"
+                )
+            else:
                 return (
                     await func(**kwargs)
                     if asyncio.iscoroutinefunction(func)
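Together with the previous hunk, the route factory now builds a single endpoint that branches at request time instead of two endpoint variants chosen at registration time. A simplified FastAPI-style analog (tracing and provider-data plumbing omitted; the route and helper names are illustrative):

```python
import asyncio
import json
from typing import Any, AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def _events() -> AsyncGenerator[str, None]:
    for i in range(3):
        yield f"data: {json.dumps({'i': i})}\n\n"
        await asyncio.sleep(0.01)


@app.post("/demo")
async def demo(stream: bool = False) -> Any:
    # One endpoint, branching on the request's `stream` flag --
    # the same shape as the rewritten create_dynamic_typed_route.
    if stream:
        return StreamingResponse(_events(), media_type="text/event-stream")
    return {"done": True}
```

Served with uvicorn, `POST /demo?stream=true` yields an SSE stream while `POST /demo` returns plain JSON from the same handler.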
@@ -41,6 +41,9 @@ class ProviderSpec(BaseModel):
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )

+    # used internally by the resolver; this is a hack for now
+    deps__: List[str] = Field(default_factory=list)
+

 class RoutingTable(Protocol):
     def get_provider_impl(self, routing_key: str) -> Any: ...
@@ -144,6 +144,8 @@ class ChatAgent(ShieldRunnerMixin):
     async def create_and_execute_turn(
         self, request: AgentTurnCreateRequest
     ) -> AsyncGenerator:
+        assert request.stream is True, "Non-streaming not supported"
+
         session_info = await self.storage.get_session_info(request.session_id)
         if session_info is None:
             raise ValueError(f"Session {request.session_id} not found")
@@ -635,14 +637,13 @@ class ChatAgent(ShieldRunnerMixin):
             raise ValueError(f"Session {session_id} not found")

         if session_info.memory_bank_id is None:
-            memory_bank = await self.memory_api.create_memory_bank(
-                name=f"memory_bank_{session_id}",
-                config=VectorMemoryBankConfig(
-                    embedding_model="all-MiniLM-L6-v2",
-                    chunk_size_in_tokens=512,
-                ),
+            bank_id = f"memory_bank_{session_id}"
+            memory_bank = VectorMemoryBankDef(
+                identifier=bank_id,
                 embedding_model="all-MiniLM-L6-v2",
                 chunk_size_in_tokens=512,
             )
-            bank_id = memory_bank.bank_id
+            await self.memory_api.register_memory_bank(memory_bank)
             await self.storage.add_memory_bank_to_session(session_id, bank_id)
         else:
             bank_id = session_info.memory_bank_id
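This hunk inverts bank creation: instead of asking the memory API to create a bank and hand back a server-assigned id, the agent picks a deterministic identifier and registers a definition under it, which also makes re-registration on a later turn harmless. A self-contained sketch of the register-by-identifier pattern (`VectorBankDef` and `Registry` are illustrative stand-ins, not the repo's classes):

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class VectorBankDef:
    identifier: str
    embedding_model: str = "all-MiniLM-L6-v2"
    chunk_size_in_tokens: int = 512


class Registry:
    def __init__(self) -> None:
        self._banks: Dict[str, VectorBankDef] = {}

    def register(self, bank: VectorBankDef) -> None:
        # Idempotent: registering the same identifier twice is a no-op.
        self._banks.setdefault(bank.identifier, bank)


registry = Registry()
session_id = "abc123"
bank_id = f"memory_bank_{session_id}"  # deterministic, caller-chosen
registry.register(VectorBankDef(identifier=bank_id))
print(bank_id in registry._banks)  # True
```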
@@ -100,7 +100,7 @@ class MetaReferenceAgentsImpl(Agents):
             session_id=session_id,
         )

-    async def create_agent_turn(
+    def create_agent_turn(
         self,
         agent_id: str,
         session_id: str,
@@ -113,16 +113,22 @@ class MetaReferenceAgentsImpl(Agents):
         attachments: Optional[List[Attachment]] = None,
         stream: Optional[bool] = False,
     ) -> AsyncGenerator:
-        agent = await self.get_agent(agent_id)
-
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = AgentTurnCreateRequest(
             agent_id=agent_id,
             session_id=session_id,
             messages=messages,
             attachments=attachments,
-            stream=stream,
+            stream=True,
         )
+        if stream:
+            return self._create_agent_turn_streaming(request)
+        else:
+            raise NotImplementedError("Non-streaming agent turns not yet implemented")
+
+    async def _create_agent_turn_streaming(
+        self,
+        request: AgentTurnCreateRequest,
+    ) -> AsyncGenerator:
+        agent = await self.get_agent(request.agent_id)
         async for event in agent.create_and_execute_turn(request):
             yield event
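Note that the internal request pins `stream=True` even though the public method accepts `stream=False`: the meta-reference implementation only has a streaming path, and the non-streaming branch raises `NotImplementedError`. If it were implemented, one plausible shape is to drain the stream and assemble the final turn from the collected events (hypothetical helper, not part of this diff):

```python
from typing import Any, AsyncGenerator, List


async def drain(event_gen: AsyncGenerator[Any, None]) -> List[Any]:
    # Consume the streaming implementation to completion; a non-streaming
    # create_agent_turn could build its final response from these events.
    return [event async for event in event_gen]
```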
llama_stack/providers/tests/agents/__init__.py (new file, +5)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.