Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Merge c63859a136 into b82af5b826
This commit is contained in:
Commit a112ccc3db: 2 changed files, with 96 additions and 17 deletions.
|
@@ -257,14 +257,14 @@ class LiteLLM:
|
||||||
timeout: Optional[float] = 600,
|
timeout: Optional[float] = 600,
|
||||||
max_retries: Optional[int] = litellm.num_retries,
|
max_retries: Optional[int] = litellm.num_retries,
|
||||||
default_headers: Optional[Mapping[str, str]] = None,
|
default_headers: Optional[Mapping[str, str]] = None,
|
||||||
):
|
) -> None:
|
||||||
self.params = locals()
|
self.params = locals()
|
||||||
self.chat = Chat(self.params, router_obj=None)
|
self.chat = Chat(self.params, router_obj=None)
|
||||||
|
|
||||||
|
|
||||||
class Chat:
|
class Chat:
|
||||||
def __init__(self, params, router_obj: Optional[Any]):
|
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
|
||||||
self.params = params
|
self.params = dict(params) # Use a shallow copy to prevent side effects.
|
||||||
if self.params.get("acompletion", False) is True:
|
if self.params.get("acompletion", False) is True:
|
||||||
self.params.pop("acompletion")
|
self.params.pop("acompletion")
|
||||||
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
|
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
|
||||||
|
@@ -275,38 +275,40 @@ class Chat:
|
||||||
|
|
||||||
|
|
||||||
class Completions:
|
class Completions:
|
||||||
def __init__(self, params, router_obj: Optional[Any]):
|
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
|
||||||
self.params = params
|
self.params = dict(params)
|
||||||
self.router_obj = router_obj
|
self.router_obj = router_obj
|
||||||
|
|
||||||
def create(self, messages, model=None, **kwargs):
|
def create(self, messages, model=None, **kwargs):
|
||||||
|
params = dict(self.params)
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
self.params[k] = v
|
params[k] = v
|
||||||
model = model or self.params.get("model")
|
model = model or params.get("model")
|
||||||
if self.router_obj is not None:
|
if self.router_obj is not None:
|
||||||
response = self.router_obj.completion(
|
response = self.router_obj.completion(
|
||||||
model=model, messages=messages, **self.params
|
model=model, messages=messages, **params
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = completion(model=model, messages=messages, **self.params)
|
response = completion(model=model, messages=messages, **params)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
class AsyncCompletions:
|
class AsyncCompletions:
|
||||||
def __init__(self, params, router_obj: Optional[Any]):
|
def __init__(self, params: Dict[str, Any], router_obj: Optional[Any]) -> None:
|
||||||
self.params = params
|
self.params = dict(params)
|
||||||
self.router_obj = router_obj
|
self.router_obj = router_obj
|
||||||
|
|
||||||
async def create(self, messages, model=None, **kwargs):
|
async def create(self, messages, model=None, **kwargs):
|
||||||
|
params = dict(self.params)
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
self.params[k] = v
|
params[k] = v
|
||||||
model = model or self.params.get("model")
|
model = model or params.get("model")
|
||||||
if self.router_obj is not None:
|
if self.router_obj is not None:
|
||||||
response = await self.router_obj.acompletion(
|
response = await self.router_obj.acompletion(
|
||||||
model=model, messages=messages, **self.params
|
model=model, messages=messages, **params
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = await acompletion(model=model, messages=messages, **self.params)
|
response = await acompletion(model=model, messages=messages, **params)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@@ -11,7 +11,7 @@ sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
) # Adds the parent directory to the system path
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, call, patch
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
|
|
||||||
|
@@ -341,4 +341,81 @@ async def test_extra_body_with_fallback(respx_mock: respx.MockRouter, set_openro
|
||||||
# Verify the response
|
# Verify the response
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.choices[0].message.content == "Hello from mocked response!"
|
assert response.choices[0].message.content == "Hello from mocked response!"
|
||||||
|
|
||||||
|
|
||||||
|
class Test_Chat:
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_completion(self, mocker) -> MagicMock:
|
||||||
|
return mocker.patch.object(litellm.main, "completion")
|
||||||
|
|
||||||
|
def test_calls_completion_without_side_effect_to_params(self, mock_completion):
|
||||||
|
params = {}
|
||||||
|
chatobj = litellm.main.Chat(params, router_obj=None)
|
||||||
|
chatobj.completions.create(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
foo="bar",
|
||||||
|
)
|
||||||
|
chatobj.completions.create(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
bar="foo",
|
||||||
|
)
|
||||||
|
assert mock_completion.call_args_list == [
|
||||||
|
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"),
|
||||||
|
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"),
|
||||||
|
]
|
||||||
|
assert params == {}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_acompletion(self, mocker) -> MagicMock:
|
||||||
|
return mocker.patch.object(litellm.main, "acompletion")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_acompletion_without_side_effect_to_params(self, mock_acompletion):
|
||||||
|
params = {"acompletion": True}
|
||||||
|
chatobj = litellm.main.Chat(params, router_obj=None)
|
||||||
|
await chatobj.completions.create( # type: ignore
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
foo="bar",
|
||||||
|
)
|
||||||
|
await chatobj.completions.create( # type: ignore
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
bar="foo",
|
||||||
|
)
|
||||||
|
assert mock_acompletion.call_args_list == [
|
||||||
|
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], foo="bar"),
|
||||||
|
call(model="gemini/gemini-1.5-flash", messages=[{"role": "user", "content": "hello"}], bar="foo"),
|
||||||
|
]
|
||||||
|
assert params == {"acompletion": True}
|
||||||
|
|
||||||
|
def test_calls_completion_with_router_obj(self, mocker):
|
||||||
|
router_obj = mocker.MagicMock()
|
||||||
|
chatobj = litellm.main.Chat({}, router_obj=router_obj)
|
||||||
|
chatobj.completions.create(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
foo="bar",
|
||||||
|
)
|
||||||
|
router_obj.completion.assert_called_once_with(
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
foo="bar"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_acompletion_with_router_obj(self, mocker):
|
||||||
|
router_obj = mocker.AsyncMock()
|
||||||
|
chatobj = litellm.main.Chat({"acompletion": True}, router_obj=router_obj)
|
||||||
|
await chatobj.completions.create( # type: ignore
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
foo="bar",
|
||||||
|
)
|
||||||
|
router_obj.acompletion.assert_called_once_with(
|
||||||
|
model="gemini/gemini-1.5-flash",
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
foo="bar"
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue