diff --git a/litellm/main.py b/litellm/main.py index de0716fd96..95e0c1888f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -257,14 +257,14 @@ class LiteLLM: timeout: Optional[float] = 600, max_retries: Optional[int] = litellm.num_retries, default_headers: Optional[Mapping[str, str]] = None, - ): + ) -> None: self.params = locals() self.chat = Chat(self.params, router_obj=None) class Chat: - def __init__(self, params, router_obj: Optional[Any]): - self.params = params + def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None: + self.params = dict(params) # Use a shallow copy to prevent side effects. if self.params.get("acompletion", False) is True: self.params.pop("acompletion") self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions( @@ -275,38 +275,40 @@ class Chat: class Completions: - def __init__(self, params, router_obj: Optional[Any]): - self.params = params + def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None: + self.params = dict(params) self.router_obj = router_obj def create(self, messages, model=None, **kwargs): + params = dict(self.params) for k, v in kwargs.items(): - self.params[k] = v - model = model or self.params.get("model") + params[k] = v + model = model or params.get("model") if self.router_obj is not None: response = self.router_obj.completion( - model=model, messages=messages, **self.params + model=model, messages=messages, **params ) else: - response = completion(model=model, messages=messages, **self.params) + response = completion(model=model, messages=messages, **params) return response class AsyncCompletions: - def __init__(self, params, router_obj: Optional[Any]): - self.params = params + def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None: + self.params = dict(params) self.router_obj = router_obj async def create(self, messages, model=None, **kwargs): + params = dict(self.params) for k, v in kwargs.items(): - self.params[k] = v - model = model or self.params.get("model") + params[k] = v + model = model or params.get("model") if self.router_obj is not None: response = await self.router_obj.acompletion( - model=model, messages=messages, **self.params + model=model, messages=messages, **params ) else: - response = await acompletion(model=model, messages=messages, **self.params) + response = await acompletion(model=model, messages=messages, **params) return response diff --git a/tests/litellm/test_main.py b/tests/litellm/test_main.py index b3e085df6c..453e2f6677 100644 --- a/tests/litellm/test_main.py +++ b/tests/litellm/test_main.py @@ -11,7 +11,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import litellm @@ -341,4 +341,81 @@ async def test_extra_body_with_fallback(respx_mock: respx.MockRouter, set_openro # Verify the response assert response is not None assert response.choices[0].message.content == "Hello from mocked response!" - \ No newline at end of file + + +class Test_Chat: + @pytest.fixture + def mock_completion(self, mocker) -> MagicMock: + return mocker.patch.object(litellm.main, "completion") + + def test_calls_completion_without_side_effect_to_params(self, mock_completion): + params = {} + chatobj = litellm.main.Chat(params, router_obj=None) + chatobj.completions.create( + messages=[{"role": "user", "content": "hello"}], + model="gemini/gemini-1.5-flash", + foo="bar", + ) + chatobj.completions.create( + messages=[{"role": "user", "content": "hello"}], + model="gemini/gemini-1.5-flash", + bar="foo", + ) + assert mock_completion.call_args_list == [ + call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"), + call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"), + ] + assert params == {} + + @pytest.fixture + def mock_acompletion(self, mocker) -> MagicMock: + return mocker.patch.object(litellm.main, "acompletion") + + @pytest.mark.asyncio + async def test_calls_acompletion_without_side_effect_to_params(self, mock_acompletion): + params = {"acompletion": True} + chatobj = litellm.main.Chat(params, router_obj=None) + await chatobj.completions.create( # type: ignore + messages=[{"role": "user", "content": "hello"}], + model="gemini/gemini-1.5-flash", + foo="bar", + ) + await chatobj.completions.create( # type: ignore + messages=[{"role": "user", "content": "hello"}], + model="gemini/gemini-1.5-flash", + bar="foo", + ) + assert mock_acompletion.call_args_list == [ + call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"), + call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"), + ] + assert params == {"acompletion": True} + + def test_calls_completion_with_router_obj(self, mocker): + router_obj = mocker.MagicMock() + chatobj = litellm.main.Chat({}, router_obj=router_obj) + chatobj.completions.create( + messages=[{"role": "user", "content": "hello"}], + model="gemini/gemini-1.5-flash", + foo="bar", + ) + router_obj.completion.assert_called_once_with( + model="gemini/gemini-1.5-flash", + messages=[{"role": "user", "content": "hello"}], + foo="bar" + ) + + @pytest.mark.asyncio + async def test_calls_acompletion_with_router_obj(self, mocker): + router_obj = mocker.AsyncMock() + chatobj = litellm.main.Chat({"acompletion": True}, router_obj=router_obj) + await chatobj.completions.create( # type: ignore + messages=[{"role": "user", "content": "hello"}], + model="gemini/gemini-1.5-flash", + foo="bar", + ) + router_obj.acompletion.assert_called_once_with( + model="gemini/gemini-1.5-flash", + messages=[{"role": "user", "content": "hello"}], + foo="bar" + )