From 00affd1f028b2a0b1822b583a72e6ab4d6304ba8 Mon Sep 17 00:00:00 2001
From: Neil Mehta
Date: Mon, 24 Mar 2025 14:10:49 -0400
Subject: [PATCH] Fix async streaming

---
 .../remote/inference/lmstudio/_client.py | 41 +++++++------------
 1 file changed, 15 insertions(+), 26 deletions(-)

diff --git a/llama_stack/providers/remote/inference/lmstudio/_client.py b/llama_stack/providers/remote/inference/lmstudio/_client.py
index f03cb7bc0..4c4df3717 100644
--- a/llama_stack/providers/remote/inference/lmstudio/_client.py
+++ b/llama_stack/providers/remote/inference/lmstudio/_client.py
@@ -192,18 +192,7 @@ class LMStudioClient:
             )
         )
 
-        # Convert to list to avoid StopIteration issues
-        try:
-            chunks = await asyncio.to_thread(list, prediction_stream)
-        except StopIteration:
-            # Handle StopIteration by returning an empty list
-            chunks = []
-        except Exception as e:
-            self._log_error(e, "converting chat stream to list")
-            chunks = []
-
-        # Yield each chunk
-        for chunk in chunks:
+        async for chunk in self._async_iterate(prediction_stream):
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
@@ -332,17 +321,7 @@ class LMStudioClient:
                 config=config,
                 response_format=json_schema,
             )
-
-        try:
-            chunks = await asyncio.to_thread(list, prediction_stream)
-        except StopIteration:
-            # Handle StopIteration by returning an empty list
-            chunks = []
-        except Exception as e:
-            self._log_error(e, "converting completion stream to list")
-            chunks = []
-
-        for chunk in chunks:
+        async for chunk in self._async_iterate(prediction_stream):
             yield CompletionResponseStreamChunk(
                 delta=chunk.content,
             )
@@ -455,9 +434,19 @@ class LMStudioClient:
 
     async def _async_iterate(self, iterable):
         """Asynchronously iterate over a synchronous iterable."""
-        # Convert the synchronous iterable to a list first to avoid StopIteration issues
-        items = await asyncio.to_thread(list, iterable)
-        for item in items:
+        iterator = iter(iterable)
+
+        def safe_next(it):
+            """This is necessary to communicate StopIteration across threads"""
+            try:
+                return (next(it), False)
+            except StopIteration:
+                return (None, True)
+
+        while True:
+            item, done = await asyncio.to_thread(safe_next, iterator)
+            if done:
+                break
             yield item
 
     async def _convert_request_to_rest_call(
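
For reviewers: the core of this patch is the rewritten _async_iterate, which bridges a blocking prediction stream into an async generator by pulling one item per asyncio.to_thread call instead of draining the whole stream up front, so chunks reach the caller as they arrive. Below is a minimal standalone sketch of that pattern, assuming Python 3.9+ for asyncio.to_thread; async_iterate, slow_chunks, and main are illustrative names, not part of the patched module.

import asyncio
import time

async def async_iterate(iterable):
    # Run each blocking next() call on a worker thread so the
    # event loop stays responsive while the stream produces items.
    iterator = iter(iterable)

    def safe_next(it):
        # Exhaustion is signalled with a flag because a StopIteration
        # raised in the worker cannot end the async generator directly;
        # it would surface as a RuntimeError instead.
        try:
            return next(it), False
        except StopIteration:
            return None, True

    while True:
        item, done = await asyncio.to_thread(safe_next, iterator)
        if done:
            break
        yield item

def slow_chunks():
    # Hypothetical stand-in for a blocking prediction stream.
    for token in ["Hello", ",", " world"]:
        time.sleep(0.1)  # simulate model latency
        yield token

async def main():
    async for chunk in async_iterate(slow_chunks()):
        print(chunk, end="", flush=True)
    print()

asyncio.run(main())

The (item, done) tuple is the same device the patch's safe_next uses: it moves the end-of-stream signal out of the exception channel, which is why the earlier list-materializing workaround (and its StopIteration handling) is no longer needed.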