From 0d949d71ab6f54f74ddf243f692ed9667b0db324 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 2 Apr 2024 08:49:53 -0700
Subject: [PATCH] fix(main.py): support text completion input being a list of
 strings

addresses - https://github.com/BerriAI/litellm/issues/2792, https://github.com/BerriAI/litellm/issues/2777
---
 litellm/integrations/langfuse.py            |  8 ++++
 litellm/main.py                             | 13 +++++-
 litellm/tests/test_custom_callback_input.py | 38 ++++++++++++++++
 litellm/utils.py                            | 50 +++++++++++++++------
 4 files changed, 95 insertions(+), 14 deletions(-)

diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index 281afc2d7..db7dcb8cd 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -118,6 +118,11 @@ class LangFuseLogger:
             ):
                 input = prompt
                 output = response_obj["choices"][0]["message"].json()
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.TextCompletionResponse
+            ):
+                input = prompt
+                output = response_obj.choices[0].text
             elif response_obj is not None and isinstance(
                 response_obj, litellm.ImageResponse
             ):
@@ -242,6 +247,7 @@ class LangFuseLogger:
 
         print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
 
+        print(f"response_obj: {response_obj}")
         if supports_tags:
             metadata_tags = metadata.get("tags", [])
             tags = metadata_tags
@@ -306,11 +312,13 @@ class LangFuseLogger:
         usage = None
         if response_obj is not None and response_obj.get("id", None) is not None:
             generation_id = litellm.utils.get_logging_id(start_time, response_obj)
+            print(f"getting usage, cost={cost}")
             usage = {
                 "prompt_tokens": response_obj["usage"]["prompt_tokens"],
                 "completion_tokens": response_obj["usage"]["completion_tokens"],
                 "total_cost": cost if supports_costs else None,
             }
+        print(f"constructed usage - {usage}")
         generation_name = metadata.get("generation_name", None)
         if generation_name is None:
             # just log `litellm-{call_type}` as the generation name
diff --git a/litellm/main.py b/litellm/main.py
index 1fca06fa5..ec5242ba4 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3163,7 +3163,18 @@ def text_completion(
     # these are the params supported by Completion() but not ChatCompletion
 
     # default case, non OpenAI requests go through here
-    messages = [{"role": "system", "content": prompt}]
+    # handle prompt formatting if prompt is a string vs. list of strings
+    messages = []
+    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str):
+        for p in prompt:
+            message = {"role": "user", "content": p}
+            messages.append(message)
+    elif isinstance(prompt, str):
+        messages = [{"role": "user", "content": prompt}]
+    else:
+        raise Exception(
+            f"Unmapped prompt format. Your prompt is neither a list of strings nor a string. prompt={prompt}. File an issue - https://github.com/BerriAI/litellm/issues"
+        )
     kwargs.pop("prompt", None)
     response = completion(
         model=model,
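Note on the main.py hunk above: `text_completion` now accepts either a single prompt string or a list of prompt strings; each string in a list is mapped to its own {"role": "user", ...} message in a single underlying `completion()` call. A minimal usage sketch (the Bedrock model name is borrowed from the test below; credentials and environment setup are assumed):

    import litellm

    # single string prompt (the default path now uses role "user" instead of "system")
    resp = litellm.text_completion(
        model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        prompt="Hey, how's it going?",
    )

    # list of strings - each entry becomes one user message in the same request
    resp = litellm.text_completion(
        model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
        prompt=["You are a helpful assistant.", "Hey, how's it going?"],
    )
    print(resp.choices[0].text)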
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index afaf7a54c..4ee8865b0 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -686,6 +686,44 @@ async def test_async_chat_vertex_ai_stream():
 
 
 # Text Completion
+@pytest.mark.asyncio
+async def test_async_text_completion_bedrock():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.atext_completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            prompt=["Hi 👋 - i'm async text completion bedrock"],
+        )
+        # test streaming
+        response = await litellm.atext_completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            prompt=["Hi 👋 - i'm async text completion bedrock"],
+            stream=True,
+        )
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.atext_completion(
+                model="bedrock/",
+                prompt=["Hi 👋 - i'm async text completion bedrock"],
+                stream=True,
+                api_key="my-bad-key",
+            )
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
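The new Bedrock callback test above can be run on its own with pytest (assuming AWS credentials for Bedrock are configured in the environment):

    pytest litellm/tests/test_custom_callback_input.py -k test_async_text_completion_bedrock -s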
diff --git a/litellm/utils.py b/litellm/utils.py
index ae775c3a6..c160b424e 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -852,9 +852,16 @@ class TranscriptionResponse(OpenAIObject):
 ############################################################
 
 
-def print_verbose(print_statement, logger_only: bool = False):
+def print_verbose(
+    print_statement,
+    logger_only: bool = False,
+    log_level: Literal["DEBUG", "INFO"] = "DEBUG",
+):
     try:
-        verbose_logger.debug(print_statement)
+        if log_level == "DEBUG":
+            verbose_logger.debug(print_statement)
+        elif log_level == "INFO":
+            verbose_logger.info(print_statement)
         if litellm.set_verbose == True and logger_only == False:
             print(print_statement)  # noqa
     except:
@@ -903,10 +910,20 @@ class Logging:
             raise ValueError(
                 f"Invalid call_type {call_type}. Allowed values: {allowed_values}"
             )
-        if messages is not None and isinstance(messages, str):
-            messages = [
-                {"role": "user", "content": messages}
-            ]  # convert text completion input to the chat completion format
+        if messages is not None:
+            if isinstance(messages, str):
+                messages = [
+                    {"role": "user", "content": messages}
+                ]  # convert text completion input to the chat completion format
+            elif (
+                isinstance(messages, list)
+                and len(messages) > 0
+                and isinstance(messages[0], str)
+            ):
+                new_messages = []
+                for m in messages:
+                    new_messages.append({"role": "user", "content": m})
+                messages = new_messages
         self.model = model
         self.messages = messages
         self.stream = stream
@@ -1199,6 +1216,7 @@ class Logging:
                     or isinstance(result, EmbeddingResponse)
                     or isinstance(result, ImageResponse)
                     or isinstance(result, TranscriptionResponse)
+                    or isinstance(result, TextCompletionResponse)
                 )
                 and self.stream != True
             ):  # handle streaming separately
@@ -4464,7 +4482,7 @@ def get_optional_params(
         if unsupported_params and not litellm.drop_params:
             raise UnsupportedParamsError(
                 status_code=500,
-                message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.",
+                message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
             )
 
     def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
@@ -10374,22 +10392,28 @@ def print_args_passed_to_litellm(original_function, args, kwargs):
         args_str = ", ".join(map(repr, args))
         kwargs_str = ", ".join(f"{key}={repr(value)}" for key, value in kwargs.items())
 
-        print_verbose("\n")  # new line before
-        print_verbose("\033[92mRequest to litellm:\033[0m")
+        print_verbose("\n", log_level="INFO")  # new line before
+        print_verbose("\033[92mRequest to litellm:\033[0m", log_level="INFO")
         if args and kwargs:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m",
+                log_level="INFO",
             )
         elif args:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m",
+                log_level="INFO",
             )
         elif kwargs:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m",
+                log_level="INFO",
            )
         else:
-            print_verbose(f"\033[92mlitellm.{original_function.__name__}()\033[0m")
+            print_verbose(
+                f"\033[92mlitellm.{original_function.__name__}()\033[0m",
+                log_level="INFO",
+            )
         print_verbose("\n")  # new line after
     except:
         # This should always be non blocking
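For clarity on the `Logging.__init__` change in utils.py above: it normalizes text-completion inputs into the chat-completion message format before logging. A standalone sketch of the same mapping (the helper name is illustrative, not part of this patch):

    def normalize_messages(messages):
        # str -> a single user message
        if isinstance(messages, str):
            return [{"role": "user", "content": messages}]
        # non-empty list of strings -> one user message per string
        if (
            isinstance(messages, list)
            and len(messages) > 0
            and isinstance(messages[0], str)
        ):
            return [{"role": "user", "content": m} for m in messages]
        # already chat-formatted (list of dicts) or None -> left unchanged
        return messages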