Merge pull request #2991 from BerriAI/litellm_fix_text_completion_caching
[Feat] Support + Test caching for TextCompletion
commit 0540ca4918
2 changed files with 51 additions and 1 deletion
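The feature being merged is cache support for litellm.atext_completion. As a rough caller-side sketch of what the new test below exercises (it assumes a reachable Redis instance and the same REDIS_HOST / REDIS_PORT / REDIS_PASSWORD environment variables the test reads, plus the litellm.caching import path for Cache; it is not an excerpt from the repo docs):

    # Minimal sketch mirroring the test added below: enable Redis-backed caching
    # for async text completion, send the same prompt twice, and expect the
    # second response to be served from the cache instead of the provider.
    import asyncio
    import os
    import uuid

    import litellm
    from litellm.caching import Cache  # assumed import path for the Cache class


    async def main():
        litellm.cache = Cache(
            type="redis",
            host=os.environ["REDIS_HOST"],
            port=os.environ["REDIS_PORT"],
            password=os.environ["REDIS_PASSWORD"],
            supported_call_types=["atext_completion"],
        )
        prompt = f"write a one sentence poem about: {uuid.uuid4()}"

        response1 = await litellm.atext_completion(
            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
        )
        await asyncio.sleep(0.5)  # give the async cache write a moment to land

        response2 = await litellm.atext_completion(
            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
        )
        assert response1.choices == response2.choices  # second call served from cache


    if __name__ == "__main__":
        asyncio.run(main())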
@@ -707,6 +707,40 @@ async def test_redis_cache_acompletion_stream():
 # test_redis_cache_acompletion_stream()


+@pytest.mark.asyncio
+async def test_redis_cache_atext_completion():
+    try:
+        litellm.set_verbose = True
+        prompt = f"write a one sentence poem about: {uuid.uuid4()}"
+        litellm.cache = Cache(
+            type="redis",
+            host=os.environ["REDIS_HOST"],
+            port=os.environ["REDIS_PORT"],
+            password=os.environ["REDIS_PASSWORD"],
+            supported_call_types=["atext_completion"],
+        )
+        print("test for caching, atext_completion")
+
+        response1 = await litellm.atext_completion(
+            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
+        )
+
+        await asyncio.sleep(0.5)
+        print("\n\n Response 1 content: ", response1, "\n\n")
+
+        response2 = await litellm.atext_completion(
+            model="gpt-3.5-turbo-instruct", prompt=prompt, max_tokens=40, temperature=1
+        )
+
+        print(response2)
+
+        assert response1 == response2
+        assert response1.choices == response2.choices
+    except Exception as e:
+        print(f"{str(e)}\n\n{traceback.format_exc()}")
+        raise e
+
+
 @pytest.mark.asyncio
 async def test_redis_cache_acompletion_stream_bedrock():
     import asyncio
@@ -2996,7 +2996,7 @@ def client(original_function):
                 )
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
-                print_verbose(f"INSIDE CHECKING CACHE")
+                print_verbose("INSIDE CHECKING CACHE")
                 if (
                     litellm.cache is not None
                     and str(original_function.__name__)
@@ -3103,6 +3103,22 @@ def client(original_function):
                                 response_object=cached_result,
                                 model_response_object=ModelResponse(),
                             )
+                    if (
+                        call_type == CallTypes.atext_completion.value
+                        and isinstance(cached_result, dict)
+                    ):
+                        if kwargs.get("stream", False) == True:
+                            cached_result = convert_to_streaming_response_async(
+                                response_object=cached_result,
+                            )
+                            cached_result = CustomStreamWrapper(
+                                completion_stream=cached_result,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
+                        else:
+                            cached_result = TextCompletionResponse(**cached_result)
                     elif call_type == CallTypes.aembedding.value and isinstance(
                         cached_result, dict
                     ):
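For context on the non-streaming branch above: a cache hit returns the stored response as a plain dict, and TextCompletionResponse(**cached_result) rebuilds the response object from it. A rough illustrative sketch of that rehydration, assuming the class lives at litellm.utils (the dict below is invented for the example, not an actual Redis entry):

    # Illustrative only: rehydrate a cached text-completion payload the way the
    # non-streaming branch above does. Field values are made up for the example.
    from litellm.utils import TextCompletionResponse

    cached_result = {
        "id": "cmpl-cached-example",  # hypothetical cached id
        "object": "text_completion",
        "created": 1712000000,
        "model": "gpt-3.5-turbo-instruct",
        "choices": [
            {"text": "a one sentence poem", "index": 0, "finish_reason": "stop"}
        ],
    }

    rehydrated = TextCompletionResponse(**cached_result)
    print(rehydrated)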