Merge pull request #371 from promptmetheus/simplify-mock-logic

Simplify mock logic
Krish Dholakia 2023-09-14 09:23:45 -07:00 committed by GitHub
commit 7c9779f0ac

@@ -72,6 +72,32 @@ async def acompletion(*args, **kwargs):
     else:
         return response
 
+## Use this in your testing pipeline, if you need to mock an LLM response
+def mock_completion(model: str, messages: List, stream: bool = False, mock_response: str = "This is a mock request", **kwargs):
+    try:
+        model_response = ModelResponse()
+        if stream: # return a generator object, iterate through the text in chunks of 3 char / chunk
+            for i in range(0, len(mock_response), 3):
+                completion_obj = {"role": "assistant", "content": mock_response[i: i+3]}
+                yield {
+                    "choices":
+                        [
+                            {
+                                "delta": completion_obj,
+                                "finish_reason": None
+                            },
+                        ]
+                }
+        else:
+            ## RESPONSE OBJECT
+            completion_response = mock_response  # use the caller-supplied mock text instead of a hardcoded string
+            model_response["choices"][0]["message"]["content"] = completion_response
+            model_response["created"] = time.time()
+            model_response["model"] = "MockResponse"
+            return model_response
+    except Exception:
+        raise Exception("Mock completion response failed")
+
 @client
 @timeout( # type: ignore
     600
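
A minimal usage sketch (not part of the diff) for the streaming path added above. It assumes mock_completion is importable from the litellm package at this commit. Note that because the function body contains a yield, calling it always returns a generator, so the sketch exercises the stream=True path, where each yielded chunk carries up to three characters of mock_response under choices[0]["delta"]["content"]:

    from litellm import mock_completion  # assumption: re-exported at package level

    chunks = mock_completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
        stream=True,
        mock_response="Hello from the mock",
    )
    text = "".join(chunk["choices"][0]["delta"]["content"] for chunk in chunks)
    assert text == "Hello from the mock"
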
@@ -96,6 +122,7 @@ def completion(
     # Optional liteLLM function params
     *,
     return_async=False,
+    mock_response: Optional[str] = None,
     api_key: Optional[str] = None,
     api_version: Optional[str] = None,
     api_base: Optional[str] = None,
@@ -118,6 +145,10 @@ def completion(
     caching = False,
     cache_params = {}, # optional to specify metadata for caching
 ) -> ModelResponse:
+    # If `mock_response` is set, execute the `mock_completion` method instead.
+    if mock_response:
+        return mock_completion(model, messages, stream=stream, mock_response=mock_response)
+
     args = locals()
     try:
         logging = litellm_logging_obj
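
A hedged sketch of the short-circuit added above: passing mock_response makes completion() return the mocked output before any provider is dispatched, so no API key or network access is needed. stream=True is used because mock_completion is a generator function (see the note under the first hunk):

    import litellm

    stream = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "ping"}],
        stream=True,
        mock_response="This is a mock request",
    )
    for chunk in stream:
        print(chunk["choices"][0]["delta"]["content"], end="")
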
@@ -978,31 +1009,6 @@ def batch_completion(
     results = [future.result() for future in completions]
     return results
 
-## Use this in your testing pipeline, if you need to mock an LLM response
-def mock_completion(model: str, messages: List, stream: bool = False, mock_response: str = "This is a mock request", **kwargs):
-    try:
-        model_response = ModelResponse()
-        if stream: # return a generator object, iterate through the text in chunks of 3 char / chunk
-            for i in range(0, len(mock_response), 3):
-                completion_obj = {"role": "assistant", "content": mock_response[i: i+3]}
-                yield {
-                    "choices":
-                        [
-                            {
-                                "delta": completion_obj,
-                                "finish_reason": None
-                            },
-                        ]
-                }
-        else:
-            ## RESPONSE OBJECT
-            completion_response = "This is a mock request"
-            model_response["choices"][0]["message"]["content"] = completion_response
-            model_response["created"] = time.time()
-            model_response["model"] = "MockResponse"
-            return model_response
-    except:
-        raise Exception("Mock completion response failed")
 ### EMBEDDING ENDPOINTS ####################
 @client
 @timeout( # type: ignore
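
To illustrate the "use this in your testing pipeline" comment, a hypothetical pytest-style sketch (the summarize helper and the test are illustrative names, not from this commit): mock_response lets code under test call completion() without provider credentials or network access:

    import litellm

    def summarize(text, **kwargs):
        # Illustrative helper: stream a completion and collect the text.
        stream = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": f"Summarize: {text}"}],
            stream=True,
            **kwargs,
        )
        return "".join(chunk["choices"][0]["delta"]["content"] for chunk in stream)

    def test_summarize_returns_mocked_text():
        assert summarize("anything", mock_response="short summary") == "short summary"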