Merge pull request #4405 from BerriAI/litellm_update_mock_completion

[Fix] - use `n` in mock completion responses
Ishaan Jaff 2024-06-25 11:20:30 -07:00 committed by GitHub
commit 2bd993039b
3 changed files with 87 additions and 10 deletions


@@ -9731,18 +9731,45 @@ class TextCompletionStreamWrapper:
         raise StopAsyncIteration


-def mock_completion_streaming_obj(model_response, mock_response, model):
+def mock_completion_streaming_obj(
+    model_response, mock_response, model, n: Optional[int] = None
+):
     for i in range(0, len(mock_response), 3):
-        completion_obj = {"role": "assistant", "content": mock_response[i : i + 3]}
-        model_response.choices[0].delta = completion_obj
+        completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
+        if n is None:
+            model_response.choices[0].delta = completion_obj
+        else:
+            _all_choices = []
+            for j in range(n):
+                _streaming_choice = litellm.utils.StreamingChoices(
+                    index=j,
+                    delta=litellm.utils.Delta(
+                        role="assistant", content=mock_response[i : i + 3]
+                    ),
+                )
+                _all_choices.append(_streaming_choice)
+            model_response.choices = _all_choices
         yield model_response


-async def async_mock_completion_streaming_obj(model_response, mock_response, model):
+async def async_mock_completion_streaming_obj(
+    model_response, mock_response, model, n: Optional[int] = None
+):
     for i in range(0, len(mock_response), 3):
         completion_obj = Delta(role="assistant", content=mock_response[i : i + 3])
-        model_response.choices[0].delta = completion_obj
-        model_response.choices[0].finish_reason = "stop"
+        if n is None:
+            model_response.choices[0].delta = completion_obj
+        else:
+            _all_choices = []
+            for j in range(n):
+                _streaming_choice = litellm.utils.StreamingChoices(
+                    index=j,
+                    delta=litellm.utils.Delta(
+                        role="assistant", content=mock_response[i : i + 3]
+                    ),
+                )
+                _all_choices.append(_streaming_choice)
+            model_response.choices = _all_choices
         yield model_response
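
For context beyond the diff: the new `n` handling can be exercised end to end through litellm's mock completion path. The sketch below is illustrative only and not part of this commit; the model name and mock text are placeholders, and it assumes the public `litellm.completion()` API with its `mock_response`, `stream`, and `n` parameters.

import litellm

# Request 3 choices per streamed chunk from the mock completion path.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    mock_response="It's simple to use and easy to get started",
    stream=True,
    n=3,
)

for chunk in response:
    # With n=3, each chunk is expected to carry three StreamingChoices
    # (indices 0-2) instead of the single default choice.
    print(len(chunk.choices), [choice.delta.content for choice in chunk.choices])

The async generator (`async_mock_completion_streaming_obj`) mirrors the same branching, so the streaming `acompletion` path should behave the same way when `n` is set.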