fixes to mock completion

Krrish Dholakia 2023-09-14 10:03:57 -07:00
parent 1b2cf704af
commit e2ea4adb84
7 changed files with 26 additions and 25 deletions

litellm/main.py

@@ -17,7 +17,8 @@ from litellm.utils import (
     CustomStreamWrapper,
     read_config_args,
     completion_with_fallbacks,
-    get_llm_provider
+    get_llm_provider,
+    mock_completion_streaming_obj
 )
 from .llms import anthropic
 from .llms import together_ai
@@ -72,30 +73,22 @@ async def acompletion(*args, **kwargs):
     else:
         return response

+## Use this in your testing pipeline, if you need to mock an LLM response
 def mock_completion(model: str, messages: List, stream: bool = False, mock_response: str = "This is a mock request", **kwargs):
     try:
-        model_response = ModelResponse()
-        if stream: # return a generator object, iterate through the text in chunks of 3 char / chunk
-            for i in range(0, len(mock_response), 3):
-                completion_obj = {"role": "assistant", "content": mock_response[i: i+3]}
-                yield {
-                    "choices":
-                        [
-                            {
-                                "delta": completion_obj,
-                                "finish_reason": None
-                            },
-                        ]
-                }
-        else:
-            ## RESPONSE OBJECT
-            completion_response = "This is a mock request"
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = "MockResponse"
-            return model_response
+        model_response = ModelResponse(stream=stream)
+        if stream is True:
+            # don't try to access stream object,
+            response = mock_completion_streaming_obj(model_response, mock_response=mock_response, model=model)
+            return response
+        completion_response = "This is a mock request"
+        model_response["choices"][0]["message"]["content"] = completion_response
+        model_response["created"] = time.time()
+        model_response["model"] = model
+        return model_response
     except:
+        traceback.print_exc()
         raise Exception("Mock completion response failed")

 @client
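For context, a minimal sketch of what the non-streaming path returns after this change (assumes litellm is importable; the field access mirrors the assignments in the hunk above):

import litellm

response = litellm.mock_completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, I'm a mock request"}],
    stream=False,
)
print(response["choices"][0]["message"]["content"])  # "This is a mock request"
print(response["model"])  # "gpt-3.5-turbo" -- now echoes the requested model instead of "MockResponse"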

litellm/tests/test_mock_request.py

@@ -13,11 +13,13 @@ def test_mock_request():
     try:
         model = "gpt-3.5-turbo"
         messages = [{"role": "user", "content": "Hey, I'm a mock request"}]
-        response = litellm.mock_completion(model=model, messages=messages)
+        response = litellm.mock_completion(model=model, messages=messages, stream=False)
         print(response)
+        print(type(response))
     except:
         traceback.print_exc()

+# test_mock_request()
 def test_streaming_mock_request():
     try:
         model = "gpt-3.5-turbo"
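The body of test_streaming_mock_request is truncated above. A rough sketch of how such a test might consume the stream=True generator, based on the mock_completion_streaming_obj helper added below (the reassembly assert is an assumption, not the repo's actual test):

def test_streaming_mock_request_sketch():
    # hypothetical variant of the truncated test: iterate the generator
    # returned for stream=True and reassemble the 3-character deltas
    model = "gpt-3.5-turbo"
    messages = [{"role": "user", "content": "Hey, I'm a mock request"}]
    response = litellm.mock_completion(model=model, messages=messages, stream=True)
    received = ""
    for chunk in response:
        received += chunk.choices[0].delta["content"]
    assert received == "This is a mock request"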

litellm/utils.py

@@ -2291,7 +2291,7 @@ class CustomStreamWrapper:
         # Log the type of the received item
         self.logging_obj.post_call(str(type(completion_stream)))
         if model in litellm.cohere_models:
-            # cohere does not return an iterator, so we need to wrap it in one
+            # these do not return an iterator, so we need to wrap it in one
             self.completion_stream = iter(completion_stream)
         else:
             self.completion_stream = completion_stream
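The distinction the comment draws is between an iterable and an iterator; a quick illustration of why the wrap is needed, using stand-in data rather than a real provider response:

chunks = ["Hel", "lo!"]   # iterable, but next(chunks) would raise TypeError
stream = iter(chunks)     # what the branch above does to completion_stream
print(next(stream))       # "Hel"
print(next(stream))       # "lo!"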
@@ -2461,6 +2461,12 @@ class CustomStreamWrapper:
             raise StopAsyncIteration

+def mock_completion_streaming_obj(model_response, mock_response, model):
+    for i in range(0, len(mock_response), 3):
+        completion_obj = {"role": "assistant", "content": mock_response[i: i+3]}
+        model_response.choices[0].delta = completion_obj
+        yield model_response
+
 ########## Reading Config File ############################
 def read_config_args(config_path):
     try:
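A quick sketch of what the new helper yields; the SimpleNamespace stand-in is hypothetical (the real caller passes a ModelResponse), but the 3-character chunking is exactly as written above:

from types import SimpleNamespace

# minimal object with the choices[0].delta shape the helper mutates
model_response = SimpleNamespace(choices=[SimpleNamespace(delta=None)])
for chunk in mock_completion_streaming_obj(model_response, mock_response="Hello world", model="gpt-3.5-turbo"):
    print(chunk.choices[0].delta["content"])  # "Hel", "lo ", "wor", "ld"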

pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.624"
+version = "0.1.625"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"