Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)
Fix a bug in meta-reference inference when stream=False
Also introduce a gross hack (to cover a grosser(?) hack) to ensure non-stream requests don't send back responses in SSE format. Not sure which of these hacks is grosser.
This commit is contained in:
parent 353c7dc82a
commit 4fa467731e

2 changed files with 30 additions and 15 deletions
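To make the bug concrete, here is a minimal, hypothetical illustration (not code from this repo; the payload shape is invented) of why a stream=False client breaks when an ordinary JSON body is delivered with SSE framing. The diff below then makes create_sse_event branch on the stream flag instead of always emitting SSE:

import json

# invented payload shape, purely for illustration
payload = {"completion_message": {"role": "assistant", "content": "hi"}}

sse_body = f"data: {json.dumps(payload)}\n\n"  # SSE framing, meant for streaming clients
plain_body = json.dumps(payload)               # what a stream=False client expects to parse

json.loads(plain_body)  # parses fine
try:
    json.loads(sse_body)  # fails: the "data: " prefix is not valid JSON
except json.JSONDecodeError as exc:
    print("a non-SSE client would choke here:", exc)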
@@ -58,13 +58,28 @@ def is_async_iterator_type(typ):
     )


-def create_sse_event(data: Any) -> str:
+def create_sse_event(data: Any, **kwargs) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)

-    return f"data: {data}\n\n"
+    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
+    #
+    # we use the return type of the function to determine if there's an AsyncGenerator
+    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
+    # parameter called stream which _changes_ the return type. one correct way to fix this is:
+    #
+    # - have separate underlying functions for streaming and non-streaming because they need
+    #   to operate differently anyhow
+    # - do a late binding of the return type based on the parameters passed in
+    if kwargs.get("stream", False):
+        return f"data: {data}\n\n"
+    else:
+        print(
+            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
+        )
+        return data


 async def global_exception_handler(request: Request, exc: Exception):
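The "late binding of the return type based on the parameters passed in" suggested in the comment above could be expressed with typing.overload and Literal. This is only a sketch under assumed names (Chunk and Response stand in for the real chunk and response models), not the repository's implementation:

from typing import AsyncIterator, Literal, Union, overload


class Chunk: ...      # stand-in for ChatCompletionResponseStreamChunk
class Response: ...   # stand-in for ChatCompletionResponse


@overload
async def chat_completion(stream: Literal[True]) -> AsyncIterator[Chunk]: ...
@overload
async def chat_completion(stream: Literal[False] = ...) -> Response: ...


async def chat_completion(stream: bool = False) -> Union[AsyncIterator[Chunk], Response]:
    # separate code paths, since streaming and non-streaming need to operate
    # differently anyhow (the other option listed in the comment above)
    if stream:
        async def _gen() -> AsyncIterator[Chunk]:
            yield Chunk()
        return _gen()
    return Response()

With annotations like these, a type checker knows that stream=True produces chunks and stream=False produces a single response, so the server would not have to guess from one static return annotation.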
@@ -226,7 +241,7 @@ def create_dynamic_typed_route(func: Any, method: str):
         async def sse_generator(event_gen):
             try:
                 async for item in event_gen:
-                    yield create_sse_event(item)
+                    yield create_sse_event(item, **kwargs)
                     await asyncio.sleep(0.01)
             except asyncio.CancelledError:
                 print("Generator cancelled")
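On the server side, the cleaner alternative hinted at above is to choose the response type from the request's own stream flag rather than from the handler's return annotation. A rough sketch, assuming a FastAPI-style app; the route path, payload shape, and helper names are made up for illustration:

import json

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse

app = FastAPI()


async def _stream_chunks():
    # stand-in for the real inference event generator
    for i in range(3):
        yield f"data: {json.dumps({'delta': i})}\n\n"


@app.post("/chat_completion")
async def chat_completion(body: dict):
    if body.get("stream", False):
        # streaming clients get a real text/event-stream response
        return StreamingResponse(_stream_chunks(), media_type="text/event-stream")
    # non-streaming clients get an ordinary JSON body
    return JSONResponse({"completion_message": {"content": "hi"}})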
@@ -6,7 +6,7 @@

 import asyncio

-from typing import AsyncIterator, List, Union
+from typing import AsyncGenerator, List

 from llama_models.sku_list import resolve_model

@@ -58,9 +58,7 @@ class MetaReferenceInferenceImpl(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
+    ) -> AsyncGenerator:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -117,6 +115,7 @@ class MetaReferenceInferenceImpl(Inference):

             if not ipython and buffer.startswith("<|python_tag|>"):
                 ipython = True
+                if request.stream:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -126,6 +125,7 @@ class MetaReferenceInferenceImpl(Inference):
                         ),
                     )
                 )

                 buffer = buffer[len("<|python_tag|>") :]
                 continue
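With chat_completion now annotated as a bare AsyncGenerator, what it yields depends on the stream flag: progress chunks only when stream=True, and a final response either way. A small self-contained sketch of that calling convention (fake_chat_completion and its dict payloads are invented stand-ins, not the real models):

import asyncio


async def fake_chat_completion(stream: bool = False):
    # imitation of the real generator: progress chunks only when streaming
    if stream:
        for token in ["hel", "lo"]:
            yield {"event_type": "progress", "delta": token}
    yield {"completion_message": {"role": "assistant", "content": "hello"}}


async def main():
    # a stream=False caller expects exactly one yielded item: the full response
    async for item in fake_chat_completion(stream=False):
        print("final:", item)

    # a stream=True caller consumes progress chunks plus the final message
    async for item in fake_chat_completion(stream=True):
        print("chunk:", item)


asyncio.run(main())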