mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore(recorder): update mocks to be closer to non-mock environment (#3442)
# What does this PR do? the @required_args decorator in openai-python is masking the async nature of the {AsyncCompletions,chat.AsyncCompletions}.create method. see https://github.com/openai/openai-python/issues/996 this means two things - 0. we cannot use iscoroutine in the recorder to detect async vs non 1. our mocks are inappropriately introducing identifiable async for (0), we update the iscoroutine check w/ detection of /v1/models, which is the only non-async function we mock & record. for (1), we could leave everything as is and assume (0) will catch errors. to be defensive, we update the unit tests to mock below create methods, allowing the true openai-python create() methods to be tested.
This commit is contained in:
parent
b6cb817897
commit
01bdcce4d2
2 changed files with 113 additions and 109 deletions
|
@ -7,7 +7,6 @@
|
|||
from __future__ import annotations # for forward references
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
@ -243,11 +242,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
global _current_mode, _current_storage
|
||||
|
||||
if _current_mode == InferenceMode.LIVE or _current_storage is None:
|
||||
# Normal operation
|
||||
if inspect.iscoroutinefunction(original_method):
|
||||
return await original_method(self, *args, **kwargs)
|
||||
else:
|
||||
if endpoint == "/v1/models":
|
||||
return original_method(self, *args, **kwargs)
|
||||
else:
|
||||
return await original_method(self, *args, **kwargs)
|
||||
|
||||
# Get base URL based on client type
|
||||
if client_type == "openai":
|
||||
|
@ -298,10 +296,10 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
)
|
||||
|
||||
elif _current_mode == InferenceMode.RECORD:
|
||||
if inspect.iscoroutinefunction(original_method):
|
||||
response = await original_method(self, *args, **kwargs)
|
||||
else:
|
||||
if endpoint == "/v1/models":
|
||||
response = original_method(self, *args, **kwargs)
|
||||
else:
|
||||
response = await original_method(self, *args, **kwargs)
|
||||
|
||||
# we want to store the result of the iterator, not the iterator itself
|
||||
if endpoint == "/v1/models":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue