Fix async streaming

This commit is contained in:
Neil Mehta 2025-03-24 14:10:49 -04:00 committed by Matt Clayton
parent 05777dfb52
commit 00affd1f02
2 changed files with 16 additions and 26 deletions

View file

@ -233,6 +233,7 @@ class InferenceRouter(Inference):
messages: List[Message] | InterleavedContent,
tool_prompt_format: Optional[ToolPromptFormat] = None,
) -> Optional[int]:
return 1
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:

View file

@ -192,18 +192,7 @@ class LMStudioClient:
)
)
# Convert to list to avoid StopIteration issues
try:
chunks = await asyncio.to_thread(list, prediction_stream)
except StopIteration:
# Handle StopIteration by returning an empty list
chunks = []
except Exception as e:
self._log_error(e, "converting chat stream to list")
chunks = []
# Yield each chunk
for chunk in chunks:
async for chunk in self._async_iterate(prediction_stream):
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
@ -332,17 +321,7 @@ class LMStudioClient:
config=config,
response_format=json_schema,
)
try:
chunks = await asyncio.to_thread(list, prediction_stream)
except StopIteration:
# Handle StopIteration by returning an empty list
chunks = []
except Exception as e:
self._log_error(e, "converting completion stream to list")
chunks = []
for chunk in chunks:
async for chunk in self._async_iterate(prediction_stream):
yield CompletionResponseStreamChunk(
delta=chunk.content,
)
@ -455,9 +434,19 @@ class LMStudioClient:
async def _async_iterate(self, iterable):
"""Asynchronously iterate over a synchronous iterable."""
# Convert the synchronous iterable to a list first to avoid StopIteration issues
items = await asyncio.to_thread(list, iterable)
for item in items:
iterator = iter(iterable)
def safe_next(it):
"""This is necessary to communicate StopIteration across threads"""
try:
return (next(it), False)
except StopIteration:
return (None, True)
while True:
item, done = await asyncio.to_thread(safe_next, iterator)
if done:
break
yield item
async def _convert_request_to_rest_call(