diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 62b033aa61..6f55165c64 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -1576,6 +1576,7 @@ class OpenAITextCompletion(BaseLLM):
                 response = openai_client.completions.create(**data)  # type: ignore
 
                 response_json = response.model_dump()
+
                 ## LOGGING
                 logging_obj.post_call(
                     input=prompt,
diff --git a/litellm/main.py b/litellm/main.py
index de611c66a3..7c775310ae 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1078,6 +1078,91 @@ def completion(
                     "api_base": api_base,
                 },
             )
+        elif (
+            custom_llm_provider == "text-completion-openai"
+            or (text_completion is True and custom_llm_provider == "openai")
+            or "ft:babbage-002" in model
+            or "ft:davinci-002" in model  # support for finetuned completion models
+        ):
+            openai.api_type = "openai"
+
+            api_base = (
+                api_base
+                or litellm.api_base
+                or get_secret("OPENAI_API_BASE")
+                or "https://api.openai.com/v1"
+            )
+
+            openai.api_version = None
+            # set API KEY
+
+            api_key = (
+                api_key
+                or litellm.api_key
+                or litellm.openai_key
+                or get_secret("OPENAI_API_KEY")
+            )
+
+            headers = headers or litellm.headers
+
+            ## LOAD CONFIG - if set
+            config = litellm.OpenAITextCompletionConfig.get_config()
+            for k, v in config.items():
+                if (
+                    k not in optional_params
+                ):  # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in
+                    optional_params[k] = v
+            if litellm.organization:
+                openai.organization = litellm.organization
+
+            if (
+                len(messages) > 0
+                and "content" in messages[0]
+                and type(messages[0]["content"]) == list
+            ):
+                # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
+                # https://platform.openai.com/docs/api-reference/completions/create
+                prompt = messages[0]["content"]
+            else:
+                prompt = " ".join([message["content"] for message in messages])  # type: ignore
+
+            ## COMPLETION CALL
+            _response = openai_text_completions.completion(
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                api_key=api_key,
+                api_base=api_base,
+                acompletion=acompletion,
+                client=client,  # pass AsyncOpenAI, OpenAI client
+                logging_obj=logging,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                timeout=timeout,  # type: ignore
+            )
+
+            if (
+                optional_params.get("stream", False) == False
+                and acompletion == False
+                and text_completion == False
+            ):
+                # convert to chat completion response
+                _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
+                    response_object=_response, model_response_object=model_response
+                )
+
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=_response,
+                    additional_args={"headers": headers},
+                )
+            response = _response
+
         elif (
             model in litellm.open_ai_chat_completion_models
             or custom_llm_provider == "custom_openai"
@@ -1164,89 +1249,6 @@ def completion(
                 original_response=response,
                 additional_args={"headers": headers},
             )
-        elif (
-            custom_llm_provider == "text-completion-openai"
-            or "ft:babbage-002" in model
-            or "ft:davinci-002" in model  # support for finetuned completion models
-        ):
-            openai.api_type = "openai"
-
-            api_base = (
-                api_base
-                or litellm.api_base
-                or get_secret("OPENAI_API_BASE")
-                or "https://api.openai.com/v1"
-            )
-
-            openai.api_version = None
-            # set API KEY
-
-            api_key = (
-                api_key
-                or litellm.api_key
-                or litellm.openai_key
-                or get_secret("OPENAI_API_KEY")
-            )
-
-            headers = headers or litellm.headers
-
-            ## LOAD CONFIG - if set
-            config = litellm.OpenAITextCompletionConfig.get_config()
-            for k, v in config.items():
-                if (
-                    k not in optional_params
-                ):  # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in
-                    optional_params[k] = v
-            if litellm.organization:
-                openai.organization = litellm.organization
-
-            if (
-                len(messages) > 0
-                and "content" in messages[0]
-                and type(messages[0]["content"]) == list
-            ):
-                # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
-                # https://platform.openai.com/docs/api-reference/completions/create
-                prompt = messages[0]["content"]
-            else:
-                prompt = " ".join([message["content"] for message in messages])  # type: ignore
-
-            ## COMPLETION CALL
-            _response = openai_text_completions.completion(
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                api_key=api_key,
-                api_base=api_base,
-                acompletion=acompletion,
-                client=client,  # pass AsyncOpenAI, OpenAI client
-                logging_obj=logging,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                timeout=timeout,  # type: ignore
-            )
-
-            if (
-                optional_params.get("stream", False) == False
-                and acompletion == False
-                and text_completion == False
-            ):
-                # convert to chat completion response
-                _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
-                    response_object=_response, model_response_object=model_response
-                )
-
-            if optional_params.get("stream", False) or acompletion == True:
-                ## LOGGING
-                logging.post_call(
-                    input=messages,
-                    api_key=api_key,
-                    original_response=_response,
-                    additional_args={"headers": headers},
-                )
-            response = _response
         elif (
             "replicate" in model
             or custom_llm_provider == "replicate"
diff --git a/litellm/tests/test_text_completion.py b/litellm/tests/test_text_completion.py
index cac448c630..32c969ac72 100644
--- a/litellm/tests/test_text_completion.py
+++ b/litellm/tests/test_text_completion.py
@@ -1,24 +1,31 @@
-import sys, os, asyncio
+import asyncio
+import os
+import sys
 import traceback
+
 from dotenv import load_dotenv
 
 load_dotenv()
-import os, io
+import io
+import os
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
+from unittest.mock import MagicMock, patch
+
 import pytest
+
 import litellm
 from litellm import (
-    embedding,
-    completion,
-    text_completion,
-    completion_cost,
-    atext_completion,
+    RateLimitError,
     TextCompletionResponse,
+    atext_completion,
+    completion,
+    completion_cost,
+    embedding,
+    text_completion,
 )
-from litellm import RateLimitError
 
 litellm.num_retries = 3
 
@@ -4082,9 +4089,10 @@ async def test_async_text_completion_chat_model_stream():
 async def test_completion_codestral_fim_api():
     try:
         litellm.set_verbose = True
-        from litellm._logging import verbose_logger
         import logging
 
+        from litellm._logging import verbose_logger
+
         verbose_logger.setLevel(level=logging.DEBUG)
         response = await litellm.atext_completion(
             model="text-completion-codestral/codestral-2405",
@@ -4113,9 +4121,10 @@ async def test_completion_codestral_fim_api():
 @pytest.mark.asyncio
 async def test_completion_codestral_fim_api_stream():
     try:
-        from litellm._logging import verbose_logger
         import logging
 
+        from litellm._logging import verbose_logger
+
         litellm.set_verbose = False
 
         # verbose_logger.setLevel(level=logging.DEBUG)
@@ -4145,3 +4154,47 @@ async def test_completion_codestral_fim_api_stream():
         # assert cost > 0.0
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
+
+def mock_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.model_dump.return_value = {
+        "id": "cmpl-7a59383dd4234092b9e5d652a7ab8143",
+        "object": "text_completion",
+        "created": 1718824735,
+        "model": "Sao10K/L3-70B-Euryale-v2.1",
+        "choices": [
+            {
+                "index": 0,
+                "text": ") might be faster than then answering, and the added time it takes for the",
+                "logprobs": None,
+                "finish_reason": "length",
+                "stop_reason": None,
+            }
+        ],
+        "usage": {"prompt_tokens": 2, "total_tokens": 18, "completion_tokens": 16},
+    }
+    return mock_response
+
+
+def test_completion_vllm():
+    """
+    Asserts a text completion call for vllm actually goes to the text completion endpoint
+    """
+    from openai import OpenAI
+
+    client = OpenAI(api_key="my-fake-key")
+
+    with patch.object(client.completions, "create", side_effect=mock_post) as mock_call:
+        response = text_completion(
+            model="openai/gemini-1.5-flash",
+            prompt="ping",
+            client=client,
+        )
+        print(response)
+
+        assert response.usage.prompt_tokens == 2
+
+        mock_call.assert_called_once()