further simplify the async management

commit a7caacf1bf
parent c6403706b4

2 changed files with 15 additions and 23 deletions
@@ -275,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     if recording:
         response_body = recording["response"]["body"]
 
-        if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
+        if recording["response"].get("is_streaming", False):
 
             async def replay_stream():
                 for chunk in response_body:
@@ -298,6 +298,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     else:
         response = original_method(self, *args, **kwargs)
 
+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
+
         request_data = {
             "method": method,
             "url": url,
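The recording path now collects the async iterator returned by the models endpoint into a plain list before storing it, since the iterator itself can only be consumed once. A minimal sketch of that pattern, with a hypothetical fake_models_list standing in for the client's real call:

import asyncio


async def fake_models_list():
    # hypothetical stand-in for the async iterator the client returns
    yield "model-a"
    yield "model-b"


async def main():
    response = fake_models_list()
    # store the yielded items, not the single-use iterator itself
    response = [m async for m in response]
    print(response)  # ['model-a', 'model-b']


asyncio.run(main())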
@@ -310,7 +314,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
         # Determine if this is a streaming request based on request parameters
         is_streaming = body.get("stream", False)
 
-        if is_streaming or endpoint == "/v1/models":
+        if is_streaming:
             # For streaming responses, we need to collect all chunks immediately before yielding
             # This ensures the recording is saved even if the generator isn't fully consumed
             chunks = []
@@ -377,15 +381,10 @@ def patch_inference_clients():
     )
 
     def patched_models_list(self, *args, **kwargs):
-        import asyncio
-
-        task = asyncio.create_task(
-            _patched_inference_method(_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs)
-        )
-
         async def _iter():
-            result = await task
-            async for item in result:
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
                 yield item
 
         return _iter()
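This is where the async management gets simpler: the old wrapper eagerly scheduled the patched call with asyncio.create_task and then awaited the task inside the generator; the new wrapper just awaits the coroutine lazily inside _iter(), so nothing runs until the caller actually iterates and there is no task object to track. A runnable sketch of the new shape, with a hypothetical _patched coroutine standing in for _patched_inference_method:

import asyncio


async def _patched(*args, **kwargs):
    # hypothetical stand-in for _patched_inference_method, which now
    # returns a fully materialized list for /v1/models
    return ["model-a", "model-b"]


def patched_models_list(*args, **kwargs):
    async def _iter():
        # await the coroutine lazily, inside the generator: no create_task,
        # no dangling task to manage
        for item in await _patched(*args, **kwargs):
            yield item

    return _iter()


async def main():
    async for item in patched_models_list():
        print(item)


asyncio.run(main())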