Merge pull request #4405 from BerriAI/litellm_update_mock_completion

[Fix] - use `n` in mock completion responses
Ishaan Jaff, 2024-06-25 11:20:30 -07:00, committed by GitHub
commit 2bd993039b
3 changed files with 87 additions and 10 deletions
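
In short: `mock_completion` previously ignored `n` (the OpenAI parameter for the number of choices to return), so mock responses always carried a single choice. With this change, passing `n` produces that many choices in both regular and streaming mock responses. A minimal usage sketch, mirroring the new test below (model and message values are illustrative):

    import litellm

    # n=3 requests three mock choices; each carries the default mock
    # content "This is a mock request".
    response = litellm.mock_completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
        n=3,
    )
    assert len(response.choices) == 3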


@@ -429,6 +429,7 @@ def mock_completion(
     model: str,
     messages: List,
     stream: Optional[bool] = False,
+    n: Optional[int] = None,
     mock_response: Union[str, Exception, dict] = "This is a mock request",
     mock_tool_calls: Optional[List] = None,
     logging=None,
@@ -487,18 +488,32 @@ def mock_completion(
         if kwargs.get("acompletion", False) == True:
             return CustomStreamWrapper(
                 completion_stream=async_mock_completion_streaming_obj(
-                    model_response, mock_response=mock_response, model=model
+                    model_response, mock_response=mock_response, model=model, n=n
                 ),
                 model=model,
                 custom_llm_provider="openai",
                 logging_obj=logging,
             )
         response = mock_completion_streaming_obj(
-            model_response, mock_response=mock_response, model=model
+            model_response,
+            mock_response=mock_response,
+            model=model,
+            n=n,
         )
         return response
-    model_response["choices"][0]["message"]["content"] = mock_response
+    if n is None:
+        model_response["choices"][0]["message"]["content"] = mock_response
+    else:
+        _all_choices = []
+        for i in range(n):
+            _choice = litellm.utils.Choices(
+                index=i,
+                message=litellm.utils.Message(
+                    content=mock_response, role="assistant"
+                ),
+            )
+            _all_choices.append(_choice)
+        model_response["choices"] = _all_choices
     model_response["created"] = int(time.time())
     model_response["model"] = model
@@ -945,6 +960,7 @@ def completion(
             model,
             messages,
             stream=stream,
+            n=n,
             mock_response=mock_response,
             mock_tool_calls=mock_tool_calls,
             logging=logging,
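
Since `completion()` now forwards `n` into the mock path, callers who short-circuit with `mock_response` get multi-choice mocks as well. A hedged sketch (values illustrative):

    import litellm

    # mock_response routes completion() through mock_completion();
    # n is now passed along, so the mock carries two identical choices.
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
        mock_response="pong",
        n=2,
    )
    assert [c.message.content for c in response.choices] == ["pong", "pong"]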


@@ -58,3 +58,37 @@ async def test_async_mock_streaming_request():
     assert (
         complete_response == "LiteLLM is awesome"
     ), f"Unexpected response got {complete_response}"
+
+
+def test_mock_request_n_greater_than_1():
+    try:
+        model = "gpt-3.5-turbo"
+        messages = [{"role": "user", "content": "Hey, I'm a mock request"}]
+        response = litellm.mock_completion(model=model, messages=messages, n=5)
+        print("response: ", response)
+        assert len(response.choices) == 5
+        for choice in response.choices:
+            assert choice.message.content == "This is a mock request"
+    except:
+        traceback.print_exc()
+
+
+@pytest.mark.asyncio()
+async def test_async_mock_streaming_request_n_greater_than_1():
+    generator = await litellm.acompletion(
+        messages=[{"role": "user", "content": "Why is LiteLLM amazing?"}],
+        mock_response="LiteLLM is awesome",
+        stream=True,
+        model="gpt-3.5-turbo",
+        n=5,
+    )
+    complete_response = ""
+    async for chunk in generator:
+        print(chunk)
+        # complete_response += chunk["choices"][0]["delta"]["content"] or ""
+
+    # assert (
+    #     complete_response == "LiteLLM is awesome"
+    # ), f"Unexpected response got {complete_response}"


@@ -9731,18 +9731,45 @@ class TextCompletionStreamWrapper:
             raise StopAsyncIteration


-def mock_completion_streaming_obj(model_response, mock_response, model):
+def mock_completion_streaming_obj(
+    model_response, mock_response, model, n: Optional[int] = None
+):
     for i in range(0, len(mock_response), 3):
-        completion_obj = {"role": "assistant", "content": mock_response[i : i + 3]}
-        model_response.choices[0].delta = completion_obj
+        completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
+        if n is None:
+            model_response.choices[0].delta = completion_obj
+        else:
+            _all_choices = []
+            for j in range(n):
+                _streaming_choice = litellm.utils.StreamingChoices(
+                    index=j,
+                    delta=litellm.utils.Delta(
+                        role="assistant", content=mock_response[i : i + 3]
+                    ),
+                )
+                _all_choices.append(_streaming_choice)
+            model_response.choices = _all_choices
         yield model_response


-async def async_mock_completion_streaming_obj(model_response, mock_response, model):
+async def async_mock_completion_streaming_obj(
+    model_response, mock_response, model, n: Optional[int] = None
+):
     for i in range(0, len(mock_response), 3):
         completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
-        model_response.choices[0].delta = completion_obj
-        model_response.choices[0].finish_reason = "stop"
+        if n is None:
+            model_response.choices[0].delta = completion_obj
+            model_response.choices[0].finish_reason = "stop"
+        else:
+            _all_choices = []
+            for j in range(n):
+                _streaming_choice = litellm.utils.StreamingChoices(
+                    index=j,
+                    delta=litellm.utils.Delta(
+                        role="assistant", content=mock_response[i : i + 3]
+                    ),
+                )
+                _all_choices.append(_streaming_choice)
+            model_response.choices = _all_choices
         yield model_response
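
With `n` set, every yielded chunk carries `n` StreamingChoices entries, each holding the same 3-character slice of the mock string. A consumer sketch under that assumption (the synchronous mock path returns this generator directly, so exact chunk framing may differ when CustomStreamWrapper is involved):

    import litellm

    # Stream a mock response with two choices and rebuild the text per index.
    stream = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
        mock_response="LiteLLM is awesome",
        stream=True,
        n=2,
    )
    texts = {0: "", 1: ""}
    for chunk in stream:
        for choice in chunk.choices:
            texts[choice.index] += choice.delta.content or ""
    for idx, text in texts.items():
        print(idx, text)  # expected: "LiteLLM is awesome" for each index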