fix(main.py): route openai calls to /completion when text_completion is True

Krrish Dholakia 2024-06-19 12:36:43 -07:00
parent 93c5625dc6
commit 9cc104eb03
3 changed files with 149 additions and 93 deletions
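
In short: when completion() is called with text_completion=True for an "openai" provider (the path litellm.text_completion() exercises in the new test below), the request is now routed to OpenAI's /completions endpoint instead of /chat/completions. A minimal usage sketch of that path, not part of the commit; the model name is only an assumed example, and the API key is read from the OPENAI_API_KEY environment variable, as in the routing code below:

import litellm

# Hypothetical call exercising the fixed route: an "openai/" model with a plain
# prompt should now hit client.completions.create (/completions), not chat completions.
response = litellm.text_completion(
    model="openai/gpt-3.5-turbo-instruct",  # assumption: any OpenAI completions-capable model
    prompt="Say this is a test",
)
print(response.choices[0].text)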

@@ -1576,6 +1576,7 @@ class OpenAITextCompletion(BaseLLM):
                 response = openai_client.completions.create(**data)  # type: ignore
                 response_json = response.model_dump()
                 ## LOGGING
                 logging_obj.post_call(
                     input=prompt,

@@ -1078,6 +1078,91 @@ def completion(
                 "api_base": api_base,
             },
         )
+    elif (
+        custom_llm_provider == "text-completion-openai"
+        or (text_completion is True and custom_llm_provider == "openai")
+        or "ft:babbage-002" in model
+        or "ft:davinci-002" in model  # support for finetuned completion models
+    ):
+        openai.api_type = "openai"
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        openai.api_version = None
+        # set API KEY
+        api_key = (
+            api_key
+            or litellm.api_key
+            or litellm.openai_key
+            or get_secret("OPENAI_API_KEY")
+        )
+        headers = headers or litellm.headers
+        ## LOAD CONFIG - if set
+        config = litellm.OpenAITextCompletionConfig.get_config()
+        for k, v in config.items():
+            if (
+                k not in optional_params
+            ):  # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in
+                optional_params[k] = v
+        if litellm.organization:
+            openai.organization = litellm.organization
+        if (
+            len(messages) > 0
+            and "content" in messages[0]
+            and type(messages[0]["content"]) == list
+        ):
+            # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
+            # https://platform.openai.com/docs/api-reference/completions/create
+            prompt = messages[0]["content"]
+        else:
+            prompt = " ".join([message["content"] for message in messages])  # type: ignore
+        ## COMPLETION CALL
+        _response = openai_text_completions.completion(
+            model=model,
+            messages=messages,
+            model_response=model_response,
+            print_verbose=print_verbose,
+            api_key=api_key,
+            api_base=api_base,
+            acompletion=acompletion,
+            client=client,  # pass AsyncOpenAI, OpenAI client
+            logging_obj=logging,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            logger_fn=logger_fn,
+            timeout=timeout,  # type: ignore
+        )
+        if (
+            optional_params.get("stream", False) == False
+            and acompletion == False
+            and text_completion == False
+        ):
+            # convert to chat completion response
+            _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
+                response_object=_response, model_response_object=model_response
+            )
+        if optional_params.get("stream", False) or acompletion == True:
+            ## LOGGING
+            logging.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=_response,
+                additional_args={"headers": headers},
+            )
+        response = _response
     elif (
         model in litellm.open_ai_chat_completion_models
         or custom_llm_provider == "custom_openai"
@@ -1164,89 +1249,6 @@ def completion(
             original_response=response,
             additional_args={"headers": headers},
         )
-    elif (
-        custom_llm_provider == "text-completion-openai"
-        or "ft:babbage-002" in model
-        or "ft:davinci-002" in model  # support for finetuned completion models
-    ):
-        openai.api_type = "openai"
-        api_base = (
-            api_base
-            or litellm.api_base
-            or get_secret("OPENAI_API_BASE")
-            or "https://api.openai.com/v1"
-        )
-        openai.api_version = None
-        # set API KEY
-        api_key = (
-            api_key
-            or litellm.api_key
-            or litellm.openai_key
-            or get_secret("OPENAI_API_KEY")
-        )
-        headers = headers or litellm.headers
-        ## LOAD CONFIG - if set
-        config = litellm.OpenAITextCompletionConfig.get_config()
-        for k, v in config.items():
-            if (
-                k not in optional_params
-            ):  # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in
-                optional_params[k] = v
-        if litellm.organization:
-            openai.organization = litellm.organization
-        if (
-            len(messages) > 0
-            and "content" in messages[0]
-            and type(messages[0]["content"]) == list
-        ):
-            # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
-            # https://platform.openai.com/docs/api-reference/completions/create
-            prompt = messages[0]["content"]
-        else:
-            prompt = " ".join([message["content"] for message in messages])  # type: ignore
-        ## COMPLETION CALL
-        _response = openai_text_completions.completion(
-            model=model,
-            messages=messages,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            api_key=api_key,
-            api_base=api_base,
-            acompletion=acompletion,
-            client=client,  # pass AsyncOpenAI, OpenAI client
-            logging_obj=logging,
-            optional_params=optional_params,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            timeout=timeout,  # type: ignore
-        )
-        if (
-            optional_params.get("stream", False) == False
-            and acompletion == False
-            and text_completion == False
-        ):
-            # convert to chat completion response
-            _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
-                response_object=_response, model_response_object=model_response
-            )
-        if optional_params.get("stream", False) or acompletion == True:
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=_response,
-                additional_args={"headers": headers},
-            )
-        response = _response
     elif (
         "replicate" in model
         or custom_llm_provider == "replicate"
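
The added and removed branches above differ only in the second guard condition: the text-completion path now also catches plain "openai" calls when the text_completion flag is set. Restated as a standalone predicate for reference (a hypothetical helper, not part of the commit):

def routes_to_text_completion(custom_llm_provider: str, text_completion: bool, model: str) -> bool:
    """Restates the elif guard added in completion() above (hypothetical helper)."""
    return (
        custom_llm_provider == "text-completion-openai"
        or (text_completion is True and custom_llm_provider == "openai")  # condition added by this commit
        or "ft:babbage-002" in model
        or "ft:davinci-002" in model  # fine-tuned completion models
    )

# The scenario exercised by the new test below:
assert routes_to_text_completion("openai", True, "gemini-1.5-flash") is True
# Pre-existing behaviour for regular chat calls is unchanged:
assert routes_to_text_completion("openai", False, "gpt-4o") is False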

@@ -1,24 +1,31 @@
-import sys, os, asyncio
+import asyncio
+import os
+import sys
 import traceback
 from dotenv import load_dotenv
 load_dotenv()
-import os, io
+import io
+import os
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
 import pytest
 import litellm
 from litellm import (
-    embedding,
-    completion,
-    text_completion,
-    completion_cost,
-    atext_completion,
+    RateLimitError,
     TextCompletionResponse,
+    atext_completion,
+    completion,
+    completion_cost,
+    embedding,
+    text_completion,
 )
-from litellm import RateLimitError
 litellm.num_retries = 3
@@ -4082,9 +4089,10 @@ async def test_async_text_completion_chat_model_stream():
 async def test_completion_codestral_fim_api():
     try:
         litellm.set_verbose = True
-        from litellm._logging import verbose_logger
         import logging
+        from litellm._logging import verbose_logger
         verbose_logger.setLevel(level=logging.DEBUG)
         response = await litellm.atext_completion(
             model="text-completion-codestral/codestral-2405",
@@ -4113,9 +4121,10 @@ async def test_completion_codestral_fim_api():
 @pytest.mark.asyncio
 async def test_completion_codestral_fim_api_stream():
     try:
-        from litellm._logging import verbose_logger
         import logging
+        from litellm._logging import verbose_logger
         litellm.set_verbose = False
         # verbose_logger.setLevel(level=logging.DEBUG)
@@ -4145,3 +4154,47 @@ async def test_completion_codestral_fim_api_stream():
         # assert cost > 0.0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def mock_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.model_dump.return_value = {
+        "id": "cmpl-7a59383dd4234092b9e5d652a7ab8143",
+        "object": "text_completion",
+        "created": 1718824735,
+        "model": "Sao10K/L3-70B-Euryale-v2.1",
+        "choices": [
+            {
+                "index": 0,
+                "text": ") might be faster than then answering, and the added time it takes for the",
+                "logprobs": None,
+                "finish_reason": "length",
+                "stop_reason": None,
+            }
+        ],
+        "usage": {"prompt_tokens": 2, "total_tokens": 18, "completion_tokens": 16},
+    }
+    return mock_response
+
+
+def test_completion_vllm():
+    """
+    Asserts a text completion call for vllm actually goes to the text completion endpoint
+    """
+    from openai import OpenAI
+
+    client = OpenAI(api_key="my-fake-key")
+
+    with patch.object(client.completions, "create", side_effect=mock_post) as mock_call:
+        response = text_completion(
+            model="openai/gemini-1.5-flash",
+            prompt="ping",
+            client=client,
+        )
+        print(response)
+
+        assert response.usage.prompt_tokens == 2
+        mock_call.assert_called_once()
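
The mocked client above stands in for any OpenAI-compatible /completions server. A hypothetical unmocked equivalent of test_completion_vllm, assuming a vLLM (or similar) server listening at http://localhost:8000/v1 and serving the model named in the mocked payload:

import litellm

# Live counterpart of the mocked test above. Assumptions: a local OpenAI-compatible
# server at http://localhost:8000/v1 serving Sao10K/L3-70B-Euryale-v2.1.
response = litellm.text_completion(
    model="openai/Sao10K/L3-70B-Euryale-v2.1",  # model name taken from the mocked response payload
    prompt="ping",
    api_base="http://localhost:8000/v1",  # assumption: where the server is listening
    api_key="my-fake-key",  # placeholder, as in the mocked test
)
print(response.choices[0].text)
print(response.usage.prompt_tokens)

Either way, the point is the same one the mock asserts: the call is served by the completions endpoint rather than chat completions.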