Updates to server.py to clean up streaming vs. non-streaming handling

Also make sure agent turn create is correctly marked
Ashwin Bharambe authored on 2024-10-08 14:28:50 -07:00; committed by Ashwin Bharambe
parent 640c5c54f7
commit 7f1160296c
13 changed files with 115 additions and 128 deletions
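
The core of the change: instead of inferring streaming behavior from an endpoint's return-type annotation, the server now checks the request's `stream` parameter. Below is a minimal, self-contained sketch of that dispatch pattern; it is illustrative only, and names such as `fake_chat_completion`, `handle_request`, and `sse_event` are invented for this example and do not exist in the repository.

# Illustrative sketch only: dispatch on the `stream` kwarg, not on type hints.
import asyncio
import json
from typing import Any, AsyncGenerator, Dict


def sse_event(data: Any) -> str:
    # One Server-Sent Events frame: "data: <json payload>\n\n"
    return f"data: {json.dumps(data)}\n\n"


async def fake_chat_completion(stream: bool = False):
    # Stand-in endpoint whose return shape depends on the `stream` parameter.
    if stream:
        async def gen() -> AsyncGenerator[Dict[str, Any], None]:
            for i in range(3):
                yield {"delta": f"token-{i}"}
        return gen()
    return {"completion": "token-0 token-1 token-2"}


async def handle_request(**kwargs) -> Any:
    # Mirrors the spirit of is_streaming_request(): the decision comes from
    # the request's kwargs, not from the function's annotations.
    if kwargs.get("stream", False):
        return "".join([sse_event(item) async for item in await fake_chat_completion(**kwargs)])
    return await fake_chat_completion(**kwargs)


if __name__ == "__main__":
    print(asyncio.run(handle_request(stream=True)))   # three SSE frames
    print(asyncio.run(handle_request(stream=False)))  # a plain JSON-able dict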


@@ -50,8 +50,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
f"Provider `{provider.provider_type}` is not available for API `{api}`"
)
p = all_api_providers[api][provider.provider_type]
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=all_api_providers[api][provider.provider_type],
spec=p,
**(provider.dict()),
)
specs[provider.provider_id] = spec
@@ -93,6 +95,10 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
registry=registry,
module="llama_stack.distribution.routers",
api_dependencies=inner_deps,
deps__=(
[x.value for x in inner_deps]
+ [f"inner-{info.router_api.value}"]
),
),
)
}
@@ -107,6 +113,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
module="llama_stack.distribution.routers",
routing_table_api=info.routing_table_api,
api_dependencies=[info.routing_table_api],
deps__=([info.routing_table_api.value]),
),
)
}
@@ -130,6 +137,7 @@ async def resolve_impls_with_routing(run_config: StackRunConfig) -> Dict[Api, Any]:
config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
module="llama_stack.distribution.inspect",
api_dependencies=apis,
deps__=([x.value for x in apis]),
),
),
)
@@ -175,10 +183,8 @@ def topological_sort(
deps = []
for provider in providers:
for dep in provider.spec.api_dependencies:
deps.append(dep.value)
if isinstance(provider, AutoRoutedProviderSpec):
deps.append(f"inner-{provider.api}")
for dep in provider.spec.deps__:
deps.append(dep)
for dep in deps:
if dep not in visited:
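
For context, here is a hedged sketch of how a flat, precomputed `deps__` list of string keys feeds a dependency-ordered traversal like the simplified loop above. `FakeProvider`, `toposort_keys`, and the example dict are invented for illustration and are not part of the codebase.

# Illustrative sketch only: providers expose a flat, precomputed `deps__` list
# of string keys, so ordering code no longer special-cases router providers.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeProvider:
    api: str
    deps__: List[str] = field(default_factory=list)


def toposort_keys(providers_by_api: Dict[str, FakeProvider]) -> List[str]:
    visited: set = set()
    order: List[str] = []

    def dfs(key: str) -> None:
        visited.add(key)
        for dep in providers_by_api[key].deps__:
            if dep not in visited:
                dfs(dep)
        order.append(key)  # emitted only after all dependencies

    for key in providers_by_api:
        if key not in visited:
            dfs(key)
    return order


# Example: inference has no deps; safety needs inference; agents needs both.
providers = {
    "inference": FakeProvider("inference"),
    "safety": FakeProvider("safety", deps__=["inference"]),
    "agents": FakeProvider("agents", deps__=["inference", "safety"]),
}
assert toposort_keys(providers) == ["inference", "safety", "agents"]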


@@ -39,6 +39,7 @@ class CommonRoutingTableImpl(RoutingTable):
) -> None:
for obj in registry:
if obj.provider_id not in impls_by_provider_id:
print(f"{impls_by_provider_id=}")
raise ValueError(
f"Provider `{obj.provider_id}` pointed by `{obj.identifier}` not found"
)
@@ -70,7 +71,7 @@ class CommonRoutingTableImpl(RoutingTable):
def get_provider_impl(self, routing_key: str) -> Any:
if routing_key not in self.routing_key_to_object:
raise ValueError(f"Object `{routing_key}` not registered")
raise ValueError(f"`{routing_key}` not registered")
obj = self.routing_key_to_object[routing_key]
if obj.provider_id not in self.impls_by_provider_id:
@@ -86,7 +87,7 @@ class CommonRoutingTableImpl(RoutingTable):
async def register_object(self, obj: RoutableObject):
if obj.identifier in self.routing_key_to_object:
print(f"Object `{obj.identifier}` is already registered")
print(f"`{obj.identifier}` is already registered")
return
if not obj.provider_id:
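
To make the registration/lookup contract above concrete, here is a toy sketch under invented names (`ToyRoutingTable` is not real code): duplicate identifiers are skipped with a message, and unknown routing keys or provider ids raise ValueError.

# Demonstration only; mirrors the behavior shown in the hunks above.
from typing import Any, Dict


class ToyRoutingTable:
    def __init__(self, impls_by_provider_id: Dict[str, Any]) -> None:
        self.impls_by_provider_id = impls_by_provider_id
        self.routing_key_to_object: Dict[str, Dict[str, str]] = {}

    def register_object(self, identifier: str, provider_id: str) -> None:
        if identifier in self.routing_key_to_object:
            print(f"`{identifier}` is already registered")
            return
        if provider_id not in self.impls_by_provider_id:
            raise ValueError(f"Provider `{provider_id}` pointed by `{identifier}` not found")
        self.routing_key_to_object[identifier] = {"identifier": identifier, "provider_id": provider_id}

    def get_provider_impl(self, routing_key: str) -> Any:
        if routing_key not in self.routing_key_to_object:
            raise ValueError(f"`{routing_key}` not registered")
        obj = self.routing_key_to_object[routing_key]
        return self.impls_by_provider_id[obj["provider_id"]]


table = ToyRoutingTable({"meta-reference": object()})
table.register_object("llama-3-8b", "meta-reference")
table.register_object("llama-3-8b", "meta-reference")  # prints "already registered"
impl = table.get_provider_impl("llama-3-8b")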


@@ -11,13 +11,9 @@ import json
import signal
import traceback
from collections.abc import (
AsyncGenerator as AsyncGeneratorABC,
AsyncIterator as AsyncIteratorABC,
)
from contextlib import asynccontextmanager
from ssl import SSLError
from typing import Any, AsyncGenerator, AsyncIterator, Dict, get_type_hints, Optional
from typing import Any, Dict, Optional
import fire
import httpx
@@ -44,42 +40,13 @@ from llama_stack.distribution.resolver import resolve_impls_with_routing
from .endpoints import get_all_api_endpoints
def is_async_iterator_type(typ):
if hasattr(typ, "__origin__"):
origin = typ.__origin__
if isinstance(origin, type):
return issubclass(
origin,
(AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC),
)
return False
return isinstance(
typ, (AsyncIterator, AsyncGenerator, AsyncIteratorABC, AsyncGeneratorABC)
)
def create_sse_event(data: Any, **kwargs) -> str:
def create_sse_event(data: Any) -> str:
if isinstance(data, BaseModel):
data = data.json()
else:
data = json.dumps(data)
# !!FIX THIS ASAP!! grossest hack ever; not really SSE
#
# we use the return type of the function to determine if there's an AsyncGenerator
# and change the implementation to send SSE. unfortunately, chat_completion() takes a
# parameter called stream which _changes_ the return type. one correct way to fix this is:
#
# - have separate underlying functions for streaming and non-streaming because they need
# to operate differently anyhow
# - do a late binding of the return type based on the parameters passed in
if kwargs.get("stream", False):
return f"data: {data}\n\n"
else:
print(
f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
)
return data
return f"data: {data}\n\n"
async def global_exception_handler(request: Request, exc: Exception):
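
With the stream/non-stream hack removed, create_sse_event always emits a standard SSE frame. A small illustrative check of the wire format follows; `demo_sse_event` is a local stand-in for demonstration, not the server's actual function.

# Demonstration only: what a dict payload looks like as an SSE frame.
import json
from typing import Any


def demo_sse_event(data: Any) -> str:
    return f"data: {json.dumps(data)}\n\n"


frame = demo_sse_event({"delta": "Hello"})
assert frame == 'data: {"delta": "Hello"}\n\n'
# Clients split the stream on blank lines and strip the leading "data: " prefix.
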
@@ -221,65 +188,56 @@ def create_dynamic_passthrough(
return endpoint
def is_streaming_request(func_name: str, request: Request, **kwargs):
# TODO: pass the api method and punt it to the Protocol definition directly
return kwargs.get("stream", False)
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
def create_dynamic_typed_route(func: Any, method: str):
hints = get_type_hints(func)
response_model = hints.get("return")
# NOTE: I think it is better to just add a method within each Api
# "Protocol" / adapter-impl to tell what sort of a response this request
# is going to produce. /chat_completion can produce a streaming or
# non-streaming response depending on if request.stream is True / False.
is_streaming = is_async_iterator_type(response_model)
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
if is_streaming:
set_request_provider_data(request.headers)
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
async def sse_generator(event_gen):
try:
async for item in event_gen:
yield create_sse_event(item, **kwargs)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
print("Generator cancelled")
await event_gen.aclose()
except Exception as e:
traceback.print_exception(e)
yield create_sse_event(
{
"error": {
"message": str(translate_exception(e)),
},
}
)
finally:
await end_trace()
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
async def endpoint(request: Request, **kwargs):
await start_trace(func.__name__)
set_request_provider_data(request.headers)
try:
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
try:
if is_streaming:
return StreamingResponse(
sse_generator(func(**kwargs)), media_type="text/event-stream"
)
else:
return (
await func(**kwargs)
if asyncio.iscoroutinefunction(func)
else func(**kwargs)
)
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
except Exception as e:
traceback.print_exception(e)
raise translate_exception(e) from e
finally:
await end_trace()
sig = inspect.signature(func)
new_params = [