fix(main.py): fix calling openai gpt-3.5-turbo-instruct via /completions

Fixes https://github.com/BerriAI/litellm/issues/749
Krrish Dholakia 2024-07-25 09:57:19 -07:00
parent d1622f6c0c
commit 5945da4a66
5 changed files with 41 additions and 14 deletions
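
For context, the failure mode is that a completions-only OpenAI model was being routed to /chat/completions. Below is a minimal sketch of the call this commit is meant to make work, mirroring the new test added further down (assumes OPENAI_API_KEY is set in the environment):

import litellm

# gpt-3.5-turbo-instruct-0914 only supports the /v1/completions endpoint;
# with this fix, text_completion() routes it to the text-completion-openai handler.
response = litellm.text_completion(
    model="gpt-3.5-turbo-instruct-0914",
    prompt="What's the weather in SF?",
    custom_llm_provider="openai",
)
print(response["choices"][0]["text"])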


@@ -3833,7 +3833,7 @@ def text_completion(
     optional_params["custom_llm_provider"] = custom_llm_provider

     # get custom_llm_provider
-    _, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore
+    _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base)  # type: ignore

     if custom_llm_provider == "huggingface":
         # if echo == True, for TGI llms we need to set top_n_tokens to 3
@@ -3916,10 +3916,12 @@ def text_completion(
         kwargs.pop("prompt", None)

-        if model is not None and model.startswith(
-            "openai/"
+        if (
+            _model is not None and custom_llm_provider == "openai"
         ):  # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls
-            model = model.replace("openai/", "text-completion-openai/")
+            if _model not in litellm.open_ai_chat_completion_models:
+                model = "text-completion-openai/" + _model
+                optional_params.pop("custom_llm_provider", None)

         kwargs["text_completion"] = True

     response = completion(
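
The old branch only rewrote models literally prefixed with "openai/", so a bare model name like gpt-3.5-turbo-instruct-0914 fell through to the chat-completions path. The new branch keys off the provider resolved by get_llm_provider() and only rewrites models that are not in litellm's chat-model registry. A rough sketch of that decision as I read the diff (a paraphrase, not the actual function):

# Paraphrase of the routing logic added above; `chat_models` stands in for
# litellm.open_ai_chat_completion_models.
def route_openai_text_completion(_model, custom_llm_provider, chat_models):
    if _model is not None and custom_llm_provider == "openai":
        if _model not in chat_models:
            # completions-only (instruct) models go to the native /v1/completions handler
            return "text-completion-openai/" + _model
    # chat models keep their name and are served via the chat endpoint
    return _model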


@@ -1,8 +1,4 @@
 model_list:
-  - model_name: "*" # all requests where model not in your config go to this deployment
+  - model_name: "test-model"
     litellm_params:
-      model: "openai/*" # passes our validation check that a real provider is given
-      api_key: ""
-
-general_settings:
-  completion_model: "gpt-3.5-turbo"
+      model: "openai/gpt-3.5-turbo-instruct-0914"
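
The proxy config is trimmed to a single test deployment that exercises the instruct model. A hedged sketch of calling it through the proxy follows; the `litellm --config config.yaml` startup, base_url/port, and api_key are assumptions about a local setup, not part of this commit:

import openai

# Placeholders: point base_url at wherever `litellm --config config.yaml` is listening,
# and use whatever key (if any) your proxy expects.
client = openai.OpenAI(api_key="sk-anything", base_url="http://0.0.0.0:4000")

# "test-model" is the model_name from the config above; the proxy maps it to
# openai/gpt-3.5-turbo-instruct-0914, which this fix sends to /v1/completions.
completion = client.completions.create(
    model="test-model",
    prompt="What's the weather in SF?",
)
print(completion.choices[0].text)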


@@ -1,14 +1,18 @@
-import sys, os
+import os
+import sys
 import traceback
 from dotenv import load_dotenv

 load_dotenv()
-import os, io
+import io
+import os

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest

 import litellm
@@ -21,6 +25,12 @@ def test_get_llm_provider():
 # test_get_llm_provider()


+def test_get_llm_provider_gpt_instruct():
+    _, response, _, _ = litellm.get_llm_provider(model="gpt-3.5-turbo-instruct-0914")
+    assert response == "text-completion-openai"
+
+
 def test_get_llm_provider_mistral_custom_api_base():
     model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
         model="mistral/mistral-large-fr",


@@ -3840,7 +3840,26 @@ def test_completion_chatgpt_prompt():
     try:
         print("\n gpt3.5 test\n")
         response = text_completion(
-            model="gpt-3.5-turbo", prompt="What's the weather in SF?"
+            model="openai/gpt-3.5-turbo", prompt="What's the weather in SF?"
+        )
+        print(response)
+        response_str = response["choices"][0]["text"]
+        print("\n", response.choices)
+        print("\n", response.choices[0])
+        # print(response.choices[0].text)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# test_completion_chatgpt_prompt()
+
+
+def test_completion_gpt_instruct():
+    try:
+        response = text_completion(
+            model="gpt-3.5-turbo-instruct-0914",
+            prompt="What's the weather in SF?",
+            custom_llm_provider="openai",
         )
         print(response)
         response_str = response["choices"][0]["text"]


@@ -2774,7 +2774,7 @@ def get_optional_params(
                 tool_function["parameters"] = new_parameters

     def _check_valid_arg(supported_params):
-        verbose_logger.debug(
+        verbose_logger.info(
             f"\nLiteLLM completion() model= {model}; provider = {custom_llm_provider}"
         )
         verbose_logger.debug(