Fix a bug in meta-reference inference when stream=False

Also introduce a gross hack (to cover a grosser(?) hack) to ensure
non-stream requests don't send back responses in SSE format. Not sure
which of these hacks is grosser.
Ashwin Bharambe 2024-10-07 14:35:50 -07:00
parent 353c7dc82a
commit 4fa467731e
2 changed files with 30 additions and 15 deletions
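
For context, a minimal sketch of the behavior this change moves toward: one
async generator that yields incremental SSE-style chunks when stream=True and
exactly one complete response object when stream=False. The StreamChunk and
FinalResponse classes and the chat_completion body below are placeholders for
illustration only, not the actual llama-stack types or implementation.

    import asyncio
    from dataclasses import dataclass
    from typing import AsyncGenerator

    @dataclass
    class StreamChunk:  # stand-in for ChatCompletionResponseStreamChunk
        delta: str

    @dataclass
    class FinalResponse:  # stand-in for ChatCompletionResponse
        content: str

    async def chat_completion(prompt: str, stream: bool = False) -> AsyncGenerator:
        pieces = ["Hello", ", ", "world"]  # stand-in for real token generation
        if stream:
            # Streaming callers receive incremental chunks (sent to clients as SSE).
            for piece in pieces:
                yield StreamChunk(delta=piece)
        else:
            # Non-streaming callers should receive a single complete response,
            # never SSE-framed chunks.
            yield FinalResponse(content="".join(pieces))

    async def main() -> None:
        async for item in chat_completion("hi", stream=False):
            print(item)  # FinalResponse(content='Hello, world')

    asyncio.run(main())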


@@ -6,7 +6,7 @@
 import asyncio
-from typing import AsyncIterator, List, Union
+from typing import AsyncGenerator, List
 from llama_models.sku_list import resolve_model
@@ -58,9 +58,7 @@ class MetaReferenceInferenceImpl(Inference):
         tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-    ) -> AsyncIterator[
-        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
-    ]:
+    ) -> AsyncGenerator:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = ChatCompletionRequest(
             model=model,
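
The precise AsyncIterator[Union[...]] annotation is dropped in favor of the
looser AsyncGenerator. Purely as an illustration (reusing placeholder types
like those in the sketch above, not llama-stack's classes), a fully
parameterized spelling of the new contract would be:

    from typing import AsyncGenerator, Union

    class StreamChunk:  # placeholder, as in the sketch above
        pass

    class FinalResponse:  # placeholder, as in the sketch above
        pass

    # Equivalent in spirit to the bare `-> AsyncGenerator` used by the commit:
    # the generator yields either chunk type and returns nothing.
    async def chat_completion(
        model: str, stream: bool = False
    ) -> AsyncGenerator[Union[StreamChunk, FinalResponse], None]:
        yield StreamChunk() if stream else FinalResponse()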
@@ -117,15 +115,17 @@ class MetaReferenceInferenceImpl(Inference):
             if not ipython and buffer.startswith("<|python_tag|>"):
                 ipython = True
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            content="",
-                            parse_status=ToolCallParseStatus.started,
-                        ),
-                    )
-                )
+                if request.stream:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                content="",
+                                parse_status=ToolCallParseStatus.started,
+                            ),
+                        )
+                    )
                 buffer = buffer[len("<|python_tag|>") :]
                 continue
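
This hunk only guards the progress chunk behind `if request.stream:`; the hack
referenced in the commit message (ensuring non-stream requests are not sent
back in SSE format) lives in the other changed file, which is not shown here.
As a rough, hypothetical illustration with made-up names, an adapter between
the two modes might look like:

    import json
    from typing import Any, AsyncGenerator

    async def respond(gen: AsyncGenerator, stream: bool) -> Any:
        # Hypothetical adapter, not the code from this commit.
        if not stream:
            # Drain the generator and hand back the single final object,
            # with no SSE framing at all.
            final = None
            async for item in gen:
                final = item
            return final

        async def sse() -> AsyncGenerator[str, None]:
            # Streaming path: wrap each chunk in SSE `data:` framing.
            async for chunk in gen:
                yield f"data: {json.dumps(chunk.__dict__)}\n\n"

        return sse()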