diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index dd3fafd0a..7b19f7996 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -58,13 +58,28 @@ def is_async_iterator_type(typ):
     )


-def create_sse_event(data: Any) -> str:
+def create_sse_event(data: Any, **kwargs) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)

-    return f"data: {data}\n\n"
+    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
+    #
+    # we use the return type of the function to determine if there's an AsyncGenerator
+    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
+    # parameter called stream which _changes_ the return type. one correct way to fix this is:
+    #
+    # - have separate underlying functions for streaming and non-streaming because they need
+    #   to operate differently anyhow
+    # - do a late binding of the return type based on the parameters passed in
+    if kwargs.get("stream", False):
+        return f"data: {data}\n\n"
+    else:
+        print(
+            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
+        )
+        return data


 async def global_exception_handler(request: Request, exc: Exception):
@@ -226,7 +241,7 @@ def create_dynamic_typed_route(func: Any, method: str):
             async def sse_generator(event_gen):
                 try:
                     async for item in event_gen:
-                        yield create_sse_event(item)
+                        yield create_sse_event(item, **kwargs)
                         await asyncio.sleep(0.01)
                 except asyncio.CancelledError:
                     print("Generator cancelled")
diff --git a/llama_stack/providers/impls/meta_reference/inference/inference.py b/llama_stack/providers/impls/meta_reference/inference/inference.py
index f36d65c3f..a310a479a 100644
--- a/llama_stack/providers/impls/meta_reference/inference/inference.py
+++ b/llama_stack/providers/impls/meta_reference/inference/inference.py
@@ -6,7 +6,7 @@

 import asyncio

-from typing import AsyncIterator, List, Union
+from typing import AsyncGenerator, List

 from llama_models.sku_list import resolve_model

@@ -58,9 +58,7 @@ class MetaReferenceInferenceImpl(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
+    ) -> AsyncGenerator:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -117,15 +115,17 @@ class MetaReferenceInferenceImpl(Inference):

                 if not ipython and buffer.startswith("<|python_tag|>"):
                     ipython = True
-                    yield ChatCompletionResponseStreamChunk(
-                        event=ChatCompletionResponseEvent(
-                            event_type=ChatCompletionResponseEventType.progress,
-                            delta=ToolCallDelta(
-                                content="",
-                                parse_status=ToolCallParseStatus.started,
-                            ),
+                    if request.stream:
+                        yield ChatCompletionResponseStreamChunk(
+                            event=ChatCompletionResponseEvent(
+                                event_type=ChatCompletionResponseEventType.progress,
+                                delta=ToolCallDelta(
+                                    content="",
+                                    parse_status=ToolCallParseStatus.started,
+                                ),
+                            )
                         )
-                    )
+
                     buffer = buffer[len("<|python_tag|>") :]
                     continue
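
Not part of the patch: a minimal sketch of the direction the inline comment above suggests, namely keeping separate streaming and non-streaming paths and deciding the response type from the `stream` parameter at the route level rather than threading `**kwargs` into `create_sse_event()`. All names below (`chat_completion`, `encode_event`, `_stream_chunks`) are hypothetical placeholders written against plain FastAPI, not the llama_stack API.

# Hypothetical sketch only; assumes FastAPI/pydantic, and invented helper names.
import json
from typing import Any, AsyncGenerator

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

app = FastAPI()


def encode_event(data: Any) -> str:
    # Shared serialization used by both paths.
    return data.json() if isinstance(data, BaseModel) else json.dumps(data)


async def _stream_chunks() -> AsyncGenerator[str, None]:
    # Stand-in for the provider's streaming generator; each item is SSE-framed.
    for piece in ("hello", " world"):
        yield f"data: {encode_event({'delta': piece})}\n\n"


@app.post("/chat_completion")
async def chat_completion(stream: bool = False):
    if stream:
        # Streaming clients get a real SSE response.
        return StreamingResponse(_stream_chunks(), media_type="text/event-stream")
    # Non-streaming clients get a single JSON body with no SSE framing.
    return JSONResponse({"completion": "hello world"})

With this split there is no need to inspect the function's return type or to smuggle the `stream` flag into the event encoder; the two behaviors are bound to the request parameters at the point where the response object is chosen.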