Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
further simplify the async management
parent c6403706b4
commit a7caacf1bf
2 changed files with 15 additions and 23 deletions
@@ -275,7 +275,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     if recording:
         response_body = recording["response"]["body"]

-        if recording["response"].get("is_streaming", False) or endpoint == "/v1/models":
+        if recording["response"].get("is_streaming", False):

             async def replay_stream():
                 for chunk in response_body:
@@ -298,6 +298,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     else:
         response = original_method(self, *args, **kwargs)

+        # we want to store the result of the iterator, not the iterator itself
+        if endpoint == "/v1/models":
+            response = [m async for m in response]
+
     request_data = {
         "method": method,
         "url": url,
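The added comment above captures why the /v1/models response is drained before recording: an async iterator cannot be stored, but a plain list of its items can. A minimal, self-contained sketch of that pattern — the Model dataclass and fake_models_list generator are illustrative stand-ins, not code from this repo:

import asyncio
import json
from dataclasses import asdict, dataclass


@dataclass
class Model:
    id: str


async def fake_models_list():
    # stand-in for the client's async model listing
    for model_id in ("llama-3", "llama-guard"):
        yield Model(id=model_id)


async def main() -> None:
    response = fake_models_list()
    # drain the async iterator so the result can be stored and replayed later
    response = [m async for m in response]
    print(json.dumps([asdict(m) for m in response]))


asyncio.run(main())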
@@ -310,7 +314,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
     # Determine if this is a streaming request based on request parameters
     is_streaming = body.get("stream", False)

-    if is_streaming or endpoint == "/v1/models":
+    if is_streaming:
         # For streaming responses, we need to collect all chunks immediately before yielding
         # This ensures the recording is saved even if the generator isn't fully consumed
         chunks = []
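The comments in this hunk describe the recording strategy for streaming requests: consume the upstream stream eagerly, persist the chunks, then replay them to the caller. A small, self-contained sketch of that idea, with upstream_stream and save_recording as hypothetical stand-ins for the real client stream and recording store:

import asyncio


async def upstream_stream():
    # stand-in for a streaming inference response
    for token in ("hello", " ", "world"):
        yield token


def save_recording(chunks):
    # stand-in for writing the recording to storage
    print("recorded:", chunks)


async def collect_then_yield():
    # collect every chunk up front so the recording exists even if the
    # caller never finishes iterating
    chunks = [chunk async for chunk in upstream_stream()]
    save_recording(chunks)
    for chunk in chunks:
        yield chunk


async def main() -> None:
    async for chunk in collect_then_yield():
        print(chunk, end="")
    print()


asyncio.run(main())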
@@ -377,15 +381,10 @@ def patch_inference_clients():
     )

     def patched_models_list(self, *args, **kwargs):
-        import asyncio
-
-        task = asyncio.create_task(
-            _patched_inference_method(_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs)
-        )
-
         async def _iter():
-            result = await task
-            async for item in result:
+            for item in await _patched_inference_method(
+                _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
+            ):
                 yield item

         return _iter()
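Reading the new side of the hunk above, the wrapper no longer spins up an eager asyncio task; it awaits the patched method lazily inside the generator. A self-contained sketch of that shape, where _patched_inference_method and _original_methods are stubs standing in for the real recorder plumbing:

import asyncio

_original_methods = {"models_list": lambda self, *args, **kwargs: None}


async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
    # stub: pretend the recorder replayed a stored model list
    return ["llama-3", "llama-guard"]


def patched_models_list(self, *args, **kwargs):
    async def _iter():
        for item in await _patched_inference_method(
            _original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
        ):
            yield item

    return _iter()


async def main() -> None:
    async for item in patched_models_list(object()):
        print(item)


asyncio.run(main())

Because the await happens only when the caller starts iterating, the import asyncio / asyncio.create_task scaffolding removed above is no longer needed.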
@@ -221,16 +221,9 @@ class TestInferenceRecording:
     async def test_replay_mode_models(self, temp_storage_dir):
         """Test that replay mode returns stored responses without making real model listing calls."""

-        class MockAsyncPaginator:
-            def __init__(self, models: list[OpenAIModel]):
-                self._models = models
-
-            def __aiter__(self):
-                return self._async_iterator()
-
-            async def _async_iterator(self):
-                for model in self._models:
-                    yield model
+        async def _async_iterator(models):
+            for model in models:
+                yield model

         models = [
             OpenAIModel(id="foo", created=1, object="model", owned_by="test"),
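The test simplification mirrors the production change: a plain async generator function replaces the MockAsyncPaginator class, since all the test needs is something it can async-iterate over. A standalone sketch of that mocking pattern, with Item as an illustrative stand-in for OpenAIModel:

import asyncio
from dataclasses import dataclass
from unittest.mock import Mock


@dataclass
class Item:
    id: str


async def _async_iterator(items):
    for item in items:
        yield item


async def main() -> None:
    items = [Item(id="foo"), Item(id="bar")]
    # the mock returns an async generator, which is enough to drive `async for`
    list_models = Mock(return_value=_async_iterator(items))
    assert {m.id async for m in list_models()} == {"foo", "bar"}
    list_models.assert_called_once()


asyncio.run(main())

Note that the mocked return value is a single generator object and can only be consumed once, which is why the test below constructs a fresh client and mock for each of the baseline, record, and replay phases.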
@@ -243,21 +236,21 @@ class TestInferenceRecording:

         # baseline - mock works without recording
         client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-        client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+        client.models._get_api_list = Mock(return_value=_async_iterator(models))
         assert {m.id async for m in client.models.list()} == expected_ids
         client.models._get_api_list.assert_called_once()

         # record the call
         with inference_recording(mode=InferenceMode.RECORD, storage_dir=temp_storage_dir):
             client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-            client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
             assert {m.id async for m in client.models.list()} == expected_ids
             client.models._get_api_list.assert_called_once()

         # replay the call
         with inference_recording(mode=InferenceMode.REPLAY, storage_dir=temp_storage_dir):
             client = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="test")
-            client.models._get_api_list = Mock(return_value=MockAsyncPaginator(models))
+            client.models._get_api_list = Mock(return_value=_async_iterator(models))
             assert {m.id async for m in client.models.list()} == expected_ids
             client.models._get_api_list.assert_not_called()