Fix async streaming
parent 05777dfb52
commit 00affd1f02
2 changed files with 16 additions and 26 deletions
@@ -233,6 +233,7 @@ class InferenceRouter(Inference):
         messages: List[Message] | InterleavedContent,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
     ) -> Optional[int]:
+        return 1
         if isinstance(messages, list):
             encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
         else:
@@ -192,18 +192,7 @@ class LMStudioClient:
             )
         )

-        # Convert to list to avoid StopIteration issues
-        try:
-            chunks = await asyncio.to_thread(list, prediction_stream)
-        except StopIteration:
-            # Handle StopIteration by returning an empty list
-            chunks = []
-        except Exception as e:
-            self._log_error(e, "converting chat stream to list")
-            chunks = []
-
-        # Yield each chunk
-        for chunk in chunks:
+        async for chunk in self._async_iterate(prediction_stream):
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
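Why this hunk matters: the old code drained the entire synchronous prediction stream in a worker thread via asyncio.to_thread(list, prediction_stream) before yielding anything, so the "stream" only began after generation had finished, and an escaping StopIteration had to be papered over with empty-list fallbacks. A standalone timing sketch of that old pattern (illustrative names only, not repository code):

# Minimal sketch (not repository code) showing why buffering the whole
# stream in a worker thread defeats streaming: the first chunk only becomes
# available after the generator has finished entirely.
import asyncio
import time


def slow_sync_stream(n=3, delay=0.2):
    """Stands in for LM Studio's synchronous prediction stream."""
    for i in range(n):
        time.sleep(delay)  # simulate per-token generation latency
        yield f"chunk-{i}"


async def main():
    start = time.monotonic()
    # Old pattern: drain everything in a worker thread before yielding.
    chunks = await asyncio.to_thread(list, slow_sync_stream())
    for chunk in chunks:
        # All chunks print at ~0.6s, i.e. only after full generation.
        print(f"{time.monotonic() - start:.2f}s {chunk}")


asyncio.run(main())

With the replacement line, async for chunk in self._async_iterate(prediction_stream), chunks are forwarded as they are produced; the helper itself is rewritten in the last hunk below.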
@@ -332,17 +321,7 @@ class LMStudioClient:
             config=config,
             response_format=json_schema,
         )
-
-        try:
-            chunks = await asyncio.to_thread(list, prediction_stream)
-        except StopIteration:
-            # Handle StopIteration by returning an empty list
-            chunks = []
-        except Exception as e:
-            self._log_error(e, "converting completion stream to list")
-            chunks = []
-
-        for chunk in chunks:
+        async for chunk in self._async_iterate(prediction_stream):
             yield CompletionResponseStreamChunk(
                 delta=chunk.content,
             )
@@ -455,9 +434,19 @@ class LMStudioClient:

     async def _async_iterate(self, iterable):
         """Asynchronously iterate over a synchronous iterable."""
-        # Convert the synchronous iterable to a list first to avoid StopIteration issues
-        items = await asyncio.to_thread(list, iterable)
-        for item in items:
+        iterator = iter(iterable)
+
+        def safe_next(it):
+            """This is necessary to communicate StopIteration across threads"""
+            try:
+                return (next(it), False)
+            except StopIteration:
+                return (None, True)
+
+        while True:
+            item, done = await asyncio.to_thread(safe_next, iterator)
+            if done:
+                break
             yield item

     async def _convert_request_to_rest_call(
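The rewritten _async_iterate bridges a synchronous iterator into an async generator one item at a time. The safe_next wrapper exists because a StopIteration that escapes next() in the worker thread would not end the loop cleanly: once it bubbles through the awaited call into the surrounding coroutine or generator, PEP 479 converts it into RuntimeError. Exhaustion is therefore reported as a (None, True) sentinel instead. A self-contained sketch of the same pattern outside the client (helper name and demo data are illustrative, not from the repository):

# Standalone sketch of the _async_iterate pattern introduced by this commit:
# pull items from a synchronous iterator in a worker thread and yield them
# from an async generator, one at a time, without letting StopIteration
# cross the thread/await boundary.
import asyncio


async def async_iterate(iterable):
    """Asynchronously iterate over a synchronous iterable."""
    iterator = iter(iterable)

    def safe_next(it):
        # Runs in a worker thread; converts StopIteration into a flag.
        try:
            return next(it), False
        except StopIteration:
            return None, True

    while True:
        item, done = await asyncio.to_thread(safe_next, iterator)
        if done:
            break
        yield item


async def main():
    def sync_numbers():
        yield from range(3)

    async for item in async_iterate(sync_numbers()):
        print(item)  # 0, 1, 2 — delivered without blocking the event loop


asyncio.run(main())

Compared with the old list-based version, only one item is ever buffered, the event loop stays responsive between chunks, and no exception handling is needed at the call sites.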