mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-08 19:10:56 +00:00
Updates to server.py to clean up streaming vs non-streaming stuff
Also make sure agent turn create is correctly marked
This commit is contained in:
parent
640c5c54f7
commit
7f1160296c
13 changed files with 115 additions and 128 deletions
|
|
@ -50,8 +50,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
f"Provider `{provider.provider_type}` is not available for API `{api}`"
|
||||
)
|
||||
|
||||
p = all_api_providers[api][provider.provider_type]
|
||||
p.deps__ = [a.value for a in p.api_dependencies]
|
||||
spec = ProviderWithSpec(
|
||||
spec=all_api_providers[api][provider.provider_type],
|
||||
spec=p,
|
||||
**(provider.dict()),
|
||||
)
|
||||
specs[provider.provider_id] = spec
|
||||
|
|
@ -93,6 +95,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
registry=registry,
|
||||
module="llama_stack.distribution.routers",
|
||||
api_dependencies=inner_deps,
|
||||
deps__=(
|
||||
[x.value for x in inner_deps]
|
||||
+ [f"inner-{info.router_api.value}"]
|
||||
),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
@ -107,6 +113,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
module="llama_stack.distribution.routers",
|
||||
routing_table_api=info.routing_table_api,
|
||||
api_dependencies=[info.routing_table_api],
|
||||
deps__=([info.routing_table_api.value]),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
|
@ -130,6 +137,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, An
|
|||
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
|
||||
module="llama_stack.distribution.inspect",
|
||||
api_dependencies=apis,
|
||||
deps__=([x.value for x in apis]),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
@ -175,10 +183,8 @@ def topological_sort(
|
|||
|
||||
deps = []
|
||||
for provider in providers:
|
||||
for dep in provider.spec.api_dependencies:
|
||||
deps.append(dep.value)
|
||||
if isinstance(provider, AutoRoutedProviderSpec):
|
||||
deps.append(f"inner-{provider.api}")
|
||||
for dep in provider.spec.deps__:
|
||||
deps.append(dep)
|
||||
|
||||
for dep in deps:
|
||||
if dep not in visited:
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
) -> None:
|
||||
for obj in registry:
|
||||
if obj.provider_id not in impls_by_provider_id:
|
||||
print(f"{impls_by_provider_id=}")
|
||||
raise ValueError(
|
||||
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
|
||||
)
|
||||
|
|
@ -70,7 +71,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
def get_provider_impl(self, routing_key: str) -> Any:
|
||||
if routing_key not in self.routing_key_to_object:
|
||||
raise ValueError(f"Object `{routing_key}` not registered")
|
||||
raise ValueError(f"`{routing_key}` not registered")
|
||||
|
||||
obj = self.routing_key_to_object[routing_key]
|
||||
if obj.provider_id not in self.impls_by_provider_id:
|
||||
|
|
@ -86,7 +87,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
async def register_object(self, obj: RoutableObject):
|
||||
if obj.identifier in self.routing_key_to_object:
|
||||
print(f"Object `{obj.identifier}` is already registered")
|
||||
print(f"`{obj.identifier}` is already registered")
|
||||
return
|
||||
|
||||
if not obj.provider_id:
|
||||
|
|
|
|||
|
|
@ -11,13 +11,9 @@ import json
|
|||
import signal
|
||||
import traceback
|
||||
|
||||
from collections.abc import (
|
||||
AsyncGenerator as AsyncGeneratorABC,
|
||||
AsyncIterator as AsyncIteratorABC,
|
||||
)
|
||||
from contextlib import asynccontextmanager
|
||||
from ssl import SSLError
|
||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import fire
|
||||
import httpx
|
||||
|
|
@ -44,42 +40,13 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
|
|||
from .endpoints import get_all_api_endpoints
|
||||
|
||||
|
||||
def is_async_iterator_type(typ):
|
||||
if hasattr(typ, "__origin__"):
|
||||
origin = typ.__origin__
|
||||
if isinstance(origin, type):
|
||||
return issubclass(
|
||||
origin,
|
||||
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
|
||||
)
|
||||
return False
|
||||
return isinstance(
|
||||
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
|
||||
)
|
||||
|
||||
|
||||
def create_sse_event(data: Any, **kwargs) -> str:
|
||||
def create_sse_event(data: Any) -> str:
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.json()
|
||||
else:
|
||||
data = json.dumps(data)
|
||||
|
||||
# !!FIX THIS ASAP!! grossest hack ever; not really SSE
|
||||
#
|
||||
# we use the return type of the function to determine if there's an AsyncGenerator
|
||||
# and change the implementation to send SSE. unfortunately, chat_completion() takes a
|
||||
# parameter called stream which _changes_ the return type. one correct way to fix this is:
|
||||
#
|
||||
# - have separate underlying functions for streaming and non-streaming because they need
|
||||
# to operate differently anyhow
|
||||
# - do a late binding of the return type based on the parameters passed in
|
||||
if kwargs.get("stream", False):
|
||||
return f"data: {data}\n\n"
|
||||
else:
|
||||
print(
|
||||
f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
|
||||
)
|
||||
return data
|
||||
return f"data: {data}\n\n"
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
|
|
@ -221,65 +188,56 @@ def create_dynamic_passthrough(
|
|||
return endpoint
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
|
||||
# TODO: pass the api method and punt it to the Protocol definition directly
|
||||
return kwargs.get("stream", False)
|
||||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
hints = get_type_hints(func)
|
||||
response_model = hints.get("return")
|
||||
|
||||
# NOTE: I think it is better to just add a method within each Api
|
||||
# "Protocol" / adapter-impl to tell what sort of a response this request
|
||||
# is going to produce. /chat_completion can produce a streaming or
|
||||
# non-streaming response depending on if request.stream is True / False.
|
||||
is_streaming = is_async_iterator_type(response_model)
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
if is_streaming:
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
try:
|
||||
async for item in event_gen:
|
||||
yield create_sse_event(item, **kwargs)
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.CancelledError:
|
||||
print("Generator cancelled")
|
||||
await event_gen.aclose()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
yield create_sse_event(
|
||||
{
|
||||
"error": {
|
||||
"message": str(translate_exception(e)),
|
||||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
try:
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
try:
|
||||
if is_streaming:
|
||||
return StreamingResponse(
|
||||
sse_generator(func(**kwargs)), media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
await func(**kwargs)
|
||||
if asyncio.iscoroutinefunction(func)
|
||||
else func(**kwargs)
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue