mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge c63859a136
into b82af5b826
This commit is contained in:
commit
a112ccc3db
2 changed files with 96 additions and 17 deletions
|
@ -257,14 +257,14 @@ class LiteLLM:
|
|||
timeout: Optional[float] = 600,
|
||||
max_retries: Optional[int] = litellm.num_retries,
|
||||
default_headers: Optional[Mapping[str, str]] = None,
|
||||
):
|
||||
) -> None:
|
||||
self.params = locals()
|
||||
self.chat = Chat(self.params, router_obj=None)
|
||||
|
||||
|
||||
class Chat:
|
||||
def __init__(self, params, router_obj: Optional[Any]):
|
||||
self.params = params
|
||||
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
|
||||
self.params = dict(params) # Use a shallow copy to prevent side effects.
|
||||
if self.params.get("acompletion", False) is True:
|
||||
self.params.pop("acompletion")
|
||||
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
|
||||
|
@ -275,38 +275,40 @@ class Chat:
|
|||
|
||||
|
||||
class Completions:
|
||||
def __init__(self, params, router_obj: Optional[Any]):
|
||||
self.params = params
|
||||
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
|
||||
self.params = dict(params)
|
||||
self.router_obj = router_obj
|
||||
|
||||
def create(self, messages, model=None, **kwargs):
|
||||
params = dict(self.params)
|
||||
for k, v in kwargs.items():
|
||||
self.params[k] = v
|
||||
model = model or self.params.get("model")
|
||||
params[k] = v
|
||||
model = model or params.get("model")
|
||||
if self.router_obj is not None:
|
||||
response = self.router_obj.completion(
|
||||
model=model, messages=messages, **self.params
|
||||
model=model, messages=messages, **params
|
||||
)
|
||||
else:
|
||||
response = completion(model=model, messages=messages, **self.params)
|
||||
response = completion(model=model, messages=messages, **params)
|
||||
return response
|
||||
|
||||
|
||||
class AsyncCompletions:
|
||||
def __init__(self, params, router_obj: Optional[Any]):
|
||||
self.params = params
|
||||
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
|
||||
self.params = dict(params)
|
||||
self.router_obj = router_obj
|
||||
|
||||
async def create(self, messages, model=None, **kwargs):
|
||||
params = dict(self.params)
|
||||
for k, v in kwargs.items():
|
||||
self.params[k] = v
|
||||
model = model or self.params.get("model")
|
||||
params[k] = v
|
||||
model = model or params.get("model")
|
||||
if self.router_obj is not None:
|
||||
response = await self.router_obj.acompletion(
|
||||
model=model, messages=messages, **self.params
|
||||
model=model, messages=messages, **params
|
||||
)
|
||||
else:
|
||||
response = await acompletion(model=model, messages=messages, **self.params)
|
||||
response = await acompletion(model=model, messages=messages, **params)
|
||||
return response
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ sys.path.insert(
|
|||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import litellm
|
||||
|
||||
|
@ -341,4 +341,81 @@ async def test_extra_body_with_fallback(respx_mock: respx.MockRouter, set_openro
|
|||
# Verify the response
|
||||
assert response is not None
|
||||
assert response.choices[0].message.content == "Hello from mocked response!"
|
||||
|
||||
|
||||
|
||||
class Test_Chat:
|
||||
@pytest.fixture
|
||||
def mock_completion(self, mocker) -> MagicMock:
|
||||
return mocker.patch.object(litellm.main, "completion")
|
||||
|
||||
def test_calls_completion_without_side_effect_to_params(self, mock_completion):
|
||||
params = {}
|
||||
chatobj = litellm.main.Chat(params, router_obj=None)
|
||||
chatobj.completions.create(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gemini/gemini-1.5-flash",
|
||||
foo="bar",
|
||||
)
|
||||
chatobj.completions.create(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gemini/gemini-1.5-flash",
|
||||
bar="foo",
|
||||
)
|
||||
assert mock_completion.call_args_list == [
|
||||
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"),
|
||||
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"),
|
||||
]
|
||||
assert params == {}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_acompletion(self, mocker) -> MagicMock:
|
||||
return mocker.patch.object(litellm.main, "acompletion")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_acompletion_without_side_effect_to_params(self, mock_acompletion):
|
||||
params = {"acompletion": True}
|
||||
chatobj = litellm.main.Chat(params, router_obj=None)
|
||||
await chatobj.completions.create( # type: ignore
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gemini/gemini-1.5-flash",
|
||||
foo="bar",
|
||||
)
|
||||
await chatobj.completions.create( # type: ignore
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gemini/gemini-1.5-flash",
|
||||
bar="foo",
|
||||
)
|
||||
assert mock_acompletion.call_args_list == [
|
||||
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"),
|
||||
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"),
|
||||
]
|
||||
assert params == {"acompletion": True}
|
||||
|
||||
def test_calls_completion_with_router_obj(self, mocker):
|
||||
router_obj = mocker.MagicMock()
|
||||
chatobj = litellm.main.Chat({}, router_obj=router_obj)
|
||||
chatobj.completions.create(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gemini/gemini-1.5-flash",
|
||||
foo="bar",
|
||||
)
|
||||
router_obj.completion.assert_called_once_with(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
foo="bar"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_acompletion_with_router_obj(self, mocker):
|
||||
router_obj = mocker.AsyncMock()
|
||||
chatobj = litellm.main.Chat({"acompletion": True}, router_obj=router_obj)
|
||||
await chatobj.completions.create( # type: ignore
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gemini/gemini-1.5-flash",
|
||||
foo="bar",
|
||||
)
|
||||
router_obj.acompletion.assert_called_once_with(
|
||||
model="gemini/gemini-1.5-flash",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
foo="bar"
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue