This commit is contained in:
Jun Komoda 2025-04-24 00:55:38 -07:00 committed by GitHub
commit a112ccc3db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 96 additions and 17 deletions

View file

@ -257,14 +257,14 @@ class LiteLLM:
timeout: Optional[float] = 600,
max_retries: Optional[int] = litellm.num_retries,
default_headers: Optional[Mapping[str, str]] = None,
):
) -> None:
self.params = locals()
self.chat = Chat(self.params, router_obj=None)
class Chat:
def __init__(self, params, router_obj: Optional[Any]):
self.params = params
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
self.params = dict(params) # Use a shallow copy to prevent side effects.
if self.params.get("acompletion", False) is True:
self.params.pop("acompletion")
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
@ -275,38 +275,40 @@ class Chat:
class Completions:
def __init__(self, params, router_obj: Optional[Any]):
self.params = params
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
self.params = dict(params)
self.router_obj = router_obj
def create(self, messages, model=None, **kwargs):
params = dict(self.params)
for k, v in kwargs.items():
self.params[k] = v
model = model or self.params.get("model")
params[k] = v
model = model or params.get("model")
if self.router_obj is not None:
response = self.router_obj.completion(
model=model, messages=messages, **self.params
model=model, messages=messages, **params
)
else:
response = completion(model=model, messages=messages, **self.params)
response = completion(model=model, messages=messages, **params)
return response
class AsyncCompletions:
def __init__(self, params, router_obj: Optional[Any]):
self.params = params
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
self.params = dict(params)
self.router_obj = router_obj
async def create(self, messages, model=None, **kwargs):
params = dict(self.params)
for k, v in kwargs.items():
self.params[k] = v
model = model or self.params.get("model")
params[k] = v
model = model or params.get("model")
if self.router_obj is not None:
response = await self.router_obj.acompletion(
model=model, messages=messages, **self.params
model=model, messages=messages, **params
)
else:
response = await acompletion(model=model, messages=messages, **self.params)
response = await acompletion(model=model, messages=messages, **params)
return response

View file

@ -11,7 +11,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch
import litellm
@ -341,4 +341,81 @@ async def test_extra_body_with_fallback(respx_mock: respx.MockRouter, set_openro
# Verify the response
assert response is not None
assert response.choices[0].message.content == "Hello from mocked response!"
class Test_Chat:
@pytest.fixture
def mock_completion(self, mocker) -> MagicMock:
return mocker.patch.object(litellm.main, "completion")
def test_calls_completion_without_side_effect_to_params(self, mock_completion):
params = {}
chatobj = litellm.main.Chat(params, router_obj=None)
chatobj.completions.create(
messages=[{"role": "user", "content": "hello"}],
model="gemini/gemini-1.5-flash",
foo="bar",
)
chatobj.completions.create(
messages=[{"role": "user", "content": "hello"}],
model="gemini/gemini-1.5-flash",
bar="foo",
)
assert mock_completion.call_args_list == [
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"),
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"),
]
assert params == {}
@pytest.fixture
def mock_acompletion(self, mocker) -> MagicMock:
return mocker.patch.object(litellm.main, "acompletion")
@pytest.mark.asyncio
async def test_calls_acompletion_without_side_effect_to_params(self, mock_acompletion):
params = {"acompletion": True}
chatobj = litellm.main.Chat(params, router_obj=None)
await chatobj.completions.create( # type: ignore
messages=[{"role": "user", "content": "hello"}],
model="gemini/gemini-1.5-flash",
foo="bar",
)
await chatobj.completions.create( # type: ignore
messages=[{"role": "user", "content": "hello"}],
model="gemini/gemini-1.5-flash",
bar="foo",
)
assert mock_acompletion.call_args_list == [
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"),
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"),
]
assert params == {"acompletion": True}
def test_calls_completion_with_router_obj(self, mocker):
router_obj = mocker.MagicMock()
chatobj = litellm.main.Chat({}, router_obj=router_obj)
chatobj.completions.create(
messages=[{"role": "user", "content": "hello"}],
model="gemini/gemini-1.5-flash",
foo="bar",
)
router_obj.completion.assert_called_once_with(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "hello"}],
foo="bar"
)
@pytest.mark.asyncio
async def test_calls_acompletion_with_router_obj(self, mocker):
router_obj = mocker.AsyncMock()
chatobj = litellm.main.Chat({"acompletion": True}, router_obj=router_obj)
await chatobj.completions.create( # type: ignore
messages=[{"role": "user", "content": "hello"}],
model="gemini/gemini-1.5-flash",
foo="bar",
)
router_obj.acompletion.assert_called_once_with(
model="gemini/gemini-1.5-flash",
messages=[{"role": "user", "content": "hello"}],
foo="bar"
)