Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-28 15:02:37 +00:00)
Fix a bug in meta-reference inference when stream=False
Also introduce a gross hack (to cover an arguably grosser hack) to ensure non-stream requests don't send back responses in SSE format. Not sure which of these hacks is grosser.
Commit 4fa467731e (parent 353c7dc82a)
2 changed files with 30 additions and 15 deletions
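For context on why the hack is needed: a non-streaming client reads the whole response body and parses it as JSON, and SSE framing (the "data: ..." prefix plus blank-line terminator) breaks that parse. A minimal, self-contained sketch of the difference (not code from this commit):

import json

payload = json.dumps({"completion_message": {"content": "Hello"}})

# What a non-streaming client expects: a bare JSON body it can parse directly.
print(json.loads(payload))

# What SSE framing wraps around each event: "data: <json>" followed by a blank line.
sse_framed = f"data: {payload}\n\n"
try:
    json.loads(sse_framed)
except json.JSONDecodeError:
    # The "data: " prefix makes the body invalid JSON for a plain HTTP client.
    print("SSE-framed body is not parseable as plain JSON")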
First changed file (server-side request handling and SSE helpers):

@@ -58,13 +58,28 @@ def is_async_iterator_type(typ):
     )
 
 
-def create_sse_event(data: Any) -> str:
+def create_sse_event(data: Any, **kwargs) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)
 
-    return f"data: {data}\n\n"
+    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
+    #
+    # we use the return type of the function to determine if there's an AsyncGenerator
+    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
+    # parameter called stream which _changes_ the return type. one correct way to fix this is:
+    #
+    # - have separate underlying functions for streaming and non-streaming because they need
+    #   to operate differently anyhow
+    # - do a late binding of the return type based on the parameters passed in
+    if kwargs.get("stream", False):
+        return f"data: {data}\n\n"
+    else:
+        print(
+            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
+        )
+        return data
 
 
 async def global_exception_handler(request: Request, exc: Exception):
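A quick usage sketch of the patched helper, using a simplified standalone copy (the pydantic BaseModel branch and the debug print are omitted; only the stream/non-stream split is shown):

import json
from typing import Any


def create_sse_event(data: Any, **kwargs) -> str:
    # Simplified copy for illustration; the real helper also serializes
    # pydantic BaseModel instances and logs the non-SSE fallback.
    data = json.dumps(data)
    if kwargs.get("stream", False):
        return f"data: {data}\n\n"  # SSE framing for streaming clients
    return data                     # bare JSON body for non-streaming clients


print(repr(create_sse_event({"text": "hi"}, stream=True)))   # 'data: {"text": "hi"}\n\n'
print(repr(create_sse_event({"text": "hi"}, stream=False)))  # '{"text": "hi"}'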
@@ -226,7 +241,7 @@ def create_dynamic_typed_route(func: Any, method: str):
     async def sse_generator(event_gen):
         try:
             async for item in event_gen:
-                yield create_sse_event(item)
+                yield create_sse_event(item, **kwargs)
                 await asyncio.sleep(0.01)
         except asyncio.CancelledError:
             print("Generator cancelled")
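The is_async_iterator_type helper named in the first hunk header is not shown in this diff. As a rough illustration of the approach the in-code comment describes (deciding SSE vs plain responses by inspecting the declared return type), here is a hypothetical check, not the repo's actual code:

import collections.abc
import inspect
import typing
from typing import AsyncGenerator


def returns_async_stream(func) -> bool:
    # Hypothetical check: treat an endpoint as streaming if it is an async
    # generator function or its declared return type is an AsyncIterator /
    # AsyncGenerator. A runtime stream=... flag that changes the effective
    # return type silently defeats this kind of static inspection, which is
    # exactly the problem the commit's hack works around.
    if inspect.isasyncgenfunction(func):
        return True
    return_hint = typing.get_type_hints(func).get("return")
    return typing.get_origin(return_hint) in (
        collections.abc.AsyncIterator,
        collections.abc.AsyncGenerator,
    )


async def chat_completion(model: str, stream: bool = False) -> AsyncGenerator:
    yield {"delta": "..."}  # placeholder body


print(returns_async_stream(chat_completion))  # True, even when called with stream=False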
Second changed file (the meta-reference inference implementation):

@@ -6,7 +6,7 @@
 
 import asyncio
 
-from typing import AsyncIterator, List, Union
+from typing import AsyncGenerator, List
 
 from llama_models.sku_list import resolve_model
@@ -58,9 +58,7 @@ class MetaReferenceInferenceImpl(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
+    ) -> AsyncGenerator:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
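The widened AsyncGenerator annotation matches the reality that this method's behavior depends on the stream flag. The "correct way to fix this" sketched in the server-side comment could look roughly like the following (hypothetical, not code from this commit): separate underlying functions plus typing.overload to bind the declared return type to the literal value of stream.

from typing import AsyncIterator, Literal, overload


class ChatCompletionResponse: ...             # stand-in for the real response type
class ChatCompletionResponseStreamChunk: ...  # stand-in for the real chunk type


async def _chat_completion_stream() -> AsyncIterator[ChatCompletionResponseStreamChunk]:
    # Dedicated streaming implementation, as the comment suggests.
    yield ChatCompletionResponseStreamChunk()


@overload
async def chat_completion(stream: Literal[True]) -> AsyncIterator[ChatCompletionResponseStreamChunk]: ...
@overload
async def chat_completion(stream: Literal[False] = ...) -> ChatCompletionResponse: ...


async def chat_completion(stream: bool = False):
    # The return type is "late bound": type checkers pick the overload that
    # matches the literal stream argument, and the implementation dispatches
    # to a separate streaming or non-streaming code path.
    if stream:
        return _chat_completion_stream()
    return ChatCompletionResponse()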
@@ -117,15 +115,17 @@ class MetaReferenceInferenceImpl(Inference):
 
             if not ipython and buffer.startswith("<|python_tag|>"):
                 ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
+                if request.stream:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.started,
+                            ),
+                        )
+                    )
 
                 buffer = buffer[len("<|python_tag|>") :]
                 continue
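With this change the method is still always an async generator; the fix for stream=False appears to be that progress chunks are suppressed and the generator yields only a final response (an assumption based on this diff, not verified against the rest of the file). A small stand-in showing how a caller might drive both modes:

import asyncio
from typing import AsyncGenerator


class FakeInference:
    # Stand-in for MetaReferenceInferenceImpl, just to make the sketch runnable.
    async def chat_completion(self, model: str, stream: bool = False) -> AsyncGenerator:
        if stream:
            for delta in ("Hel", "lo"):
                yield {"event": "progress", "delta": delta}     # progress chunks
        else:
            yield {"completion_message": {"content": "Hello"}}  # one final response


async def main() -> None:
    impl = FakeInference()

    # Streaming: iterate over progress chunks as they arrive.
    async for chunk in impl.chat_completion(model="stub-model", stream=True):
        print("chunk:", chunk)

    # Non-streaming: the generator is expected to yield exactly one complete response.
    async for response in impl.chat_completion(model="stub-model", stream=False):
        print("response:", response)


asyncio.run(main())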