mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 18:19:51 +00:00
Adding memory provider mocks
This commit is contained in:
parent
7693786322
commit
ac1791f8b1
4 changed files with 153 additions and 3 deletions
|
|
@ -8,6 +8,8 @@ import inspect
|
|||
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from unittest.mock import MagicMock, Mock, NonCallableMagicMock, NonCallableMock
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.providers.datatypes import * # noqa: F403
|
||||
|
|
@ -33,6 +35,13 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
|
|||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
ALLOWED_MOCK_IMPLEMENTATIONS = {
|
||||
NonCallableMock,
|
||||
Mock,
|
||||
NonCallableMagicMock,
|
||||
MagicMock,
|
||||
}
|
||||
|
||||
|
||||
class InvalidProviderError(Exception):
|
||||
pass
|
||||
|
|
@ -346,8 +355,18 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
|
|||
else:
|
||||
# Check if the method is actually implemented in the class
|
||||
method_owner = next(
|
||||
(cls for cls in mro if name in cls.__dict__), None
|
||||
(
|
||||
cls
|
||||
for cls in mro
|
||||
if name in cls.__dict__
|
||||
# Finding method owner for a mock object will fail
|
||||
# implementaiont compliance check. In case we are holding
|
||||
# a mock object we are going to skip this validation
|
||||
or cls in ALLOWED_MOCK_IMPLEMENTATIONS
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if (
|
||||
method_owner is None
|
||||
or method_owner.__name__ == protocol.__name__
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue