fix(main.py): support text completion input being a list of strings

Addresses https://github.com/BerriAI/litellm/issues/2792 and https://github.com/BerriAI/litellm/issues/2777
Krrish Dholakia 2024-04-02 08:49:53 -07:00
parent 71db88115d
commit 0d949d71ab
4 changed files with 95 additions and 14 deletions
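
For context, a minimal sketch of the call pattern this commit enables, assuming AWS credentials are configured for Bedrock (the model name is taken from the new test below):

import litellm

# `prompt` may now be a list of strings; each entry becomes one
# {"role": "user", ...} message on the non-OpenAI code path.
response = litellm.text_completion(
    model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
    prompt=["Hi 👋 - i'm text completion bedrock"],
)
print(response.choices[0].text)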


@@ -118,6 +118,11 @@ class LangFuseLogger:
         ):
             input = prompt
             output = response_obj["choices"][0]["message"].json()
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.TextCompletionResponse
+        ):
+            input = prompt
+            output = response_obj.choices[0].text
         elif response_obj is not None and isinstance(
             response_obj, litellm.ImageResponse
         ):
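
The new branch logs text completions the same way as the chat-completion case above it, except the output is the first choice's raw text rather than a message object. A hedged sketch of the mapping as a standalone helper (the helper name is invented for illustration):

import litellm

def _text_completion_io(prompt, response_obj):
    # Hypothetical mirror of the new LangFuseLogger branch: text
    # completions log the prompt as-is and the first choice's raw text.
    if response_obj is not None and isinstance(
        response_obj, litellm.TextCompletionResponse
    ):
        return prompt, response_obj.choices[0].text
    return None, None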
@@ -242,6 +247,7 @@
         print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
+        print(f"response_obj: {response_obj}")
         if supports_tags:
             metadata_tags = metadata.get("tags", [])
             tags = metadata_tags
@@ -306,11 +312,13 @@
     usage = None
     if response_obj is not None and response_obj.get("id", None) is not None:
         generation_id = litellm.utils.get_logging_id(start_time, response_obj)
+        print(f"getting usage, cost={cost}")
         usage = {
             "prompt_tokens": response_obj["usage"]["prompt_tokens"],
             "completion_tokens": response_obj["usage"]["completion_tokens"],
             "total_cost": cost if supports_costs else None,
         }
+        print(f"constructed usage - {usage}")
     generation_name = metadata.get("generation_name", None)
     if generation_name is None:
         # just log `litellm-{call_type}` as the generation name
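
For reference, the usage payload handed to Langfuse keeps the same three keys; the values below are illustrative:

usage = {
    "prompt_tokens": 12,
    "completion_tokens": 34,
    "total_cost": 0.00021,  # None when supports_costs is False
}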


@@ -3163,7 +3163,18 @@ def text_completion(
     # these are the params supported by Completion() but not ChatCompletion
     # default case, non OpenAI requests go through here
-    messages = [{"role": "system", "content": prompt}]
+    # handle prompt formatting if prompt is a string vs. list of strings
+    messages = []
+    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str):
+        for p in prompt:
+            message = {"role": "user", "content": p}
+            messages.append(message)
+    elif isinstance(prompt, str):
+        messages = [{"role": "user", "content": prompt}]
+    else:
+        raise Exception(
+            f"Unmapped prompt format. Your prompt is neither a list of strings nor a string. prompt={prompt}. File an issue - https://github.com/BerriAI/litellm/issues"
+        )
     kwargs.pop("prompt", None)
     response = completion(
         model=model,
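
Note the role change: a string prompt previously became a single "system" message, while the new path emits "user" messages for both input shapes. A standalone sketch of the normalization (helper name invented for illustration):

from typing import List, Union

def _prompt_to_messages(prompt: Union[str, List[str]]) -> List[dict]:
    # Hypothetical mirror of the new text_completion() branch.
    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str):
        return [{"role": "user", "content": p} for p in prompt]
    elif isinstance(prompt, str):
        return [{"role": "user", "content": prompt}]
    raise Exception(f"Unmapped prompt format. prompt={prompt}")

assert _prompt_to_messages("hi") == [{"role": "user", "content": "hi"}]
assert _prompt_to_messages(["a", "b"]) == [
    {"role": "user", "content": "a"},
    {"role": "user", "content": "b"},
]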


@@ -686,6 +686,44 @@ async def test_async_chat_vertex_ai_stream():
 # Text Completion
+@pytest.mark.asyncio
+async def test_async_text_completion_bedrock():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.atext_completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            prompt=["Hi 👋 - i'm async text completion bedrock"],
+        )
+        # test streaming
+        response = await litellm.atext_completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            prompt=["Hi 👋 - i'm async text completion bedrock"],
+            stream=True,
+        )
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.atext_completion(
+                model="bedrock/",
+                prompt=["Hi 👋 - i'm async text completion bedrock"],
+                stream=True,
+                api_key="my-bad-key",
+            )
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
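
If a test also wants to assert on the streamed content rather than just iterate, a hedged sketch (this assumes each streamed chunk exposes choices[0].text, matching the non-streaming TextCompletionResponse shape):

import litellm

async def collect_streamed_text() -> str:
    # Accumulate the text deltas from an async text-completion stream.
    response = await litellm.atext_completion(
        model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        prompt=["Hi 👋 - i'm async text completion bedrock"],
        stream=True,
    )
    text = ""
    async for chunk in response:
        text += chunk.choices[0].text or ""
    return text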


@@ -852,9 +852,16 @@ class TranscriptionResponse(OpenAIObject):
 ############################################################
-def print_verbose(print_statement, logger_only: bool = False):
+def print_verbose(
+    print_statement,
+    logger_only: bool = False,
+    log_level: Literal["DEBUG", "INFO"] = "DEBUG",
+):
     try:
-        verbose_logger.debug(print_statement)
+        if log_level == "DEBUG":
+            verbose_logger.debug(print_statement)
+        elif log_level == "INFO":
+            verbose_logger.info(print_statement)
         if litellm.set_verbose == True and logger_only == False:
             print(print_statement)  # noqa
     except:
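
The default stays at DEBUG, so existing call sites are unchanged; callers opt in to the louder level explicitly:

print_verbose("fine-grained detail")                 # routed to verbose_logger.debug
print_verbose("Request received", log_level="INFO")  # routed to verbose_logger.info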
@@ -903,10 +910,20 @@ class Logging:
             raise ValueError(
                 f"Invalid call_type {call_type}. Allowed values: {allowed_values}"
             )
-        if messages is not None and isinstance(messages, str):
-            messages = [
-                {"role": "user", "content": messages}
-            ]  # convert text completion input to the chat completion format
+        if messages is not None:
+            if isinstance(messages, str):
+                messages = [
+                    {"role": "user", "content": messages}
+                ]  # convert text completion input to the chat completion format
+            elif (
+                isinstance(messages, list)
+                and len(messages) > 0
+                and isinstance(messages[0], str)
+            ):
+                new_messages = []
+                for m in messages:
+                    new_messages.append({"role": "user", "content": m})
+                messages = new_messages
         self.model = model
         self.messages = messages
         self.stream = stream
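
With this, self.messages is always chat-formatted by the time it is stored; e.g. a text-completion input of ["a", "b"] is saved as:

messages = [
    {"role": "user", "content": "a"},
    {"role": "user", "content": "b"},
]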
@@ -1199,6 +1216,7 @@
             or isinstance(result, EmbeddingResponse)
             or isinstance(result, ImageResponse)
             or isinstance(result, TranscriptionResponse)
+            or isinstance(result, TextCompletionResponse)
         )
         and self.stream != True
     ):  # handle streaming separately
@@ -4464,7 +4482,7 @@ def get_optional_params(
     if unsupported_params and not litellm.drop_params:
         raise UnsupportedParamsError(
             status_code=500,
-            message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.",
+            message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
         )
 def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
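
The two remediations the expanded message points at, for reference (the proxy YAML is shown as a comment so the snippet stays runnable Python):

import litellm

# SDK: silently drop unsupported params instead of raising UnsupportedParamsError
litellm.drop_params = True

# Proxy config.yaml equivalent:
# litellm_settings:
#   drop_params: true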
@@ -10374,22 +10392,28 @@ def print_args_passed_to_litellm(original_function, args, kwargs):
         args_str = ", ".join(map(repr, args))
         kwargs_str = ", ".join(f"{key}={repr(value)}" for key, value in kwargs.items())
-        print_verbose("\n")  # new line before
-        print_verbose("\033[92mRequest to litellm:\033[0m")
+        print_verbose("\n", log_level="INFO")  # new line before
+        print_verbose("\033[92mRequest to litellm:\033[0m", log_level="INFO")
         if args and kwargs:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m",
+                log_level="INFO",
             )
         elif args:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m",
+                log_level="INFO",
             )
         elif kwargs:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m",
+                log_level="INFO",
             )
         else:
-            print_verbose(f"\033[92mlitellm.{original_function.__name__}()\033[0m")
+            print_verbose(
+                f"\033[92mlitellm.{original_function.__name__}()\033[0m",
+                log_level="INFO",
+            )
         print_verbose("\n")  # new line after
     except:
         # This should always be non blocking
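
With verbose mode on, the request echo above now flows through verbose_logger at INFO and is also printed; a quick way to see it (output shape is approximate):

import litellm

litellm.set_verbose = True  # print_verbose() also prints to stdout when set
litellm.text_completion(
    model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
    prompt=["Hi 👋 - i'm text completion bedrock"],
)
# stdout includes something like:
# Request to litellm:
# litellm.text_completion(model='bedrock/...', prompt=[...])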