mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
Fix async streaming
This commit is contained in: parent 05777dfb52, commit 00affd1f02
2 changed files with 16 additions and 26 deletions
@@ -233,6 +233,7 @@ class InferenceRouter(Inference):
         messages: List[Message] | InterleavedContent,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
     ) -> Optional[int]:
+        return 1
         if isinstance(messages, list):
             encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
         else:
@@ -192,18 +192,7 @@ class LMStudioClient:
             )
         )
 
-        # Convert to list to avoid StopIteration issues
-        try:
-            chunks = await asyncio.to_thread(list, prediction_stream)
-        except StopIteration:
-            # Handle StopIteration by returning an empty list
-            chunks = []
-        except Exception as e:
-            self._log_error(e, "converting chat stream to list")
-            chunks = []
-
-        # Yield each chunk
-        for chunk in chunks:
+        async for chunk in self._async_iterate(prediction_stream):
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.progress,
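
Note on the hunk above: converting the prediction stream to a list via asyncio.to_thread(list, prediction_stream) forced the handler to wait for the entire generation before yielding its first chunk, so clients saw no incremental output. A minimal sketch of the problem, using a hypothetical fake_prediction_stream in place of the LM Studio SDK stream:

import asyncio
import time

def fake_prediction_stream():
    # Stands in for a blocking SDK stream that yields chunks as the
    # model produces them.
    for i in range(3):
        time.sleep(0.5)  # simulates waiting on the model
        yield f"chunk-{i}"

async def eager_consumer():
    # Old approach: blocks until the whole stream is exhausted, so the
    # first chunk becomes available after ~1.5s instead of ~0.5s.
    chunks = await asyncio.to_thread(list, fake_prediction_stream())
    for chunk in chunks:
        print(chunk)

asyncio.run(eager_consumer())

The same fix is applied to the completion stream in the next hunk.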
@@ -332,17 +321,7 @@ class LMStudioClient:
             config=config,
             response_format=json_schema,
         )
-
-        try:
-            chunks = await asyncio.to_thread(list, prediction_stream)
-        except StopIteration:
-            # Handle StopIteration by returning an empty list
-            chunks = []
-        except Exception as e:
-            self._log_error(e, "converting completion stream to list")
-            chunks = []
-
-        for chunk in chunks:
+        async for chunk in self._async_iterate(prediction_stream):
             yield CompletionResponseStreamChunk(
                 delta=chunk.content,
             )
@@ -455,9 +434,19 @@ class LMStudioClient:
 
     async def _async_iterate(self, iterable):
         """Asynchronously iterate over a synchronous iterable."""
-        # Convert the synchronous iterable to a list first to avoid StopIteration issues
-        items = await asyncio.to_thread(list, iterable)
-        for item in items:
+        iterator = iter(iterable)
+
+        def safe_next(it):
+            """This is necessary to communicate StopIteration across threads"""
+            try:
+                return (next(it), False)
+            except StopIteration:
+                return (None, True)
+
+        while True:
+            item, done = await asyncio.to_thread(safe_next, iterator)
+            if done:
+                break
             yield item
 
     async def _convert_request_to_rest_call(
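
For reference, the _async_iterate pattern introduced above can be exercised on its own. A self-contained sketch (the names async_iterate, demo_items, and main are illustrative): a StopIteration raised inside asyncio.to_thread would surface in the awaiting coroutine as a RuntimeError, so safe_next converts stream exhaustion into a (None, True) sentinel instead.

import asyncio

async def async_iterate(iterable):
    # Wrap a blocking iterator in an async generator, pulling one item
    # at a time in a worker thread so the event loop is never blocked.
    iterator = iter(iterable)

    def safe_next(it):
        # next() runs in the worker thread; StopIteration cannot cross
        # the thread/coroutine boundary cleanly, so exhaustion is
        # signaled with a flag.
        try:
            return (next(it), False)
        except StopIteration:
            return (None, True)

    while True:
        item, done = await asyncio.to_thread(safe_next, iterator)
        if done:
            break
        yield item

async def main():
    demo_items = ["a", "b", "c"]
    async for item in async_iterate(demo_items):
        print(item)

asyncio.run(main())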