Fix a bug in meta-reference inference when stream=False

Also introduce a gross hack (to cover a grosser(?) hack) to ensure
non-stream requests don't send back responses in SSE format. Not sure
which of these hacks is grosser.
Ashwin Bharambe 2024-10-07 14:35:50 -07:00
parent 353c7dc82a
commit 4fa467731e
2 changed files with 30 additions and 15 deletions


@@ -58,13 +58,28 @@ def is_async_iterator_type(typ):
     )


-def create_sse_event(data: Any) -> str:
+def create_sse_event(data: Any, **kwargs) -> str:
     if isinstance(data, BaseModel):
         data = data.json()
     else:
         data = json.dumps(data)

-    return f"data: {data}\n\n"
+    # !!FIX THIS ASAP!! grossest hack ever; not really SSE
+    #
+    # we use the return type of the function to determine if there's an AsyncGenerator
+    # and change the implementation to send SSE. unfortunately, chat_completion() takes a
+    # parameter called stream which _changes_ the return type. one correct way to fix this is:
+    #
+    # - have separate underlying functions for streaming and non-streaming because they need
+    #   to operate differently anyhow
+    # - do a late binding of the return type based on the parameters passed in
+    if kwargs.get("stream", False):
+        return f"data: {data}\n\n"
+    else:
+        print(
+            f"!!FIX THIS ASAP!! Sending non-SSE event because client really is non-SSE: {data}"
+        )
+        return data


 async def global_exception_handler(request: Request, exc: Exception):
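The comment above refers to the server-side trick: the route builder inspects a function's return annotation (via a helper such as is_async_iterator_type in the hunk header) to decide whether to stream SSE. A minimal sketch of that kind of check follows; it is an approximation, not copied from the repo, and the real helper may differ:

# Sketch only: detect an "async stream" endpoint from its definition or its
# declared return annotation. Approximation; not the repo's implementation.
import collections.abc
import inspect
from typing import get_origin, get_type_hints


def returns_async_stream(func) -> bool:
    # An `async def` containing `yield` is an async generator regardless of annotation.
    if inspect.isasyncgenfunction(func):
        return True
    # Otherwise fall back to the declared return type (e.g. AsyncGenerator, AsyncIterator[...]).
    ret = get_type_hints(func).get("return")
    origin = get_origin(ret) or ret
    return origin in (collections.abc.AsyncGenerator, collections.abc.AsyncIterator)

Either way, the annotation alone cannot express "streams only when stream=True", which is exactly the mismatch the hack above papers over.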
@@ -226,7 +241,7 @@ def create_dynamic_typed_route(func: Any, method: str):
         async def sse_generator(event_gen):
             try:
                 async for item in event_gen:
-                    yield create_sse_event(item)
+                    yield create_sse_event(item, **kwargs)
                     await asyncio.sleep(0.01)
             except asyncio.CancelledError:
                 print("Generator cancelled")


@@ -6,7 +6,7 @@
 import asyncio
-from typing import AsyncIterator, List, Union
+from typing import AsyncGenerator, List

 from llama_models.sku_list import resolve_model
@@ -58,9 +58,7 @@ class MetaReferenceInferenceImpl(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
+    ) -> AsyncGenerator:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
@@ -117,6 +115,7 @@ class MetaReferenceInferenceImpl(Inference):
             if not ipython and buffer.startswith("<|python_tag|>"):
                 ipython = True

+                if request.stream:
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
                             event_type=ChatCompletionResponseEventType.progress,
@@ -126,6 +125,7 @@ class MetaReferenceInferenceImpl(Inference):
                             ),
                         )
                     )
+
                 buffer = buffer[len("<|python_tag|>") :]
                 continue
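Taken together, the annotation change and the if request.stream: gate mean callers iterate the same async generator in both modes; with stream=False the gating above suggests the impl skips progress chunks and yields a single final ChatCompletionResponse. A hedged usage sketch, where impl, model, and messages are placeholders rather than values from the repo:

# Illustrative only: how a caller might consume chat_completion under the new
# bare AsyncGenerator signature. Placeholder names; not taken from the repo.
async def consume(impl, model, messages):
    # Non-streaming: drain the generator; the only item should be the full response.
    final = None
    async for item in impl.chat_completion(model=model, messages=messages, stream=False):
        final = item
    print("non-streaming result:", final)

    # Streaming: each item is a ChatCompletionResponseStreamChunk carrying an event delta.
    async for chunk in impl.chat_completion(model=model, messages=messages, stream=True):
        print("chunk:", chunk.event.event_type)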