forked from phoenix/litellm-mirror

fix(main.py): support text completion input being a list of strings

Addresses https://github.com/BerriAI/litellm/issues/2792 and https://github.com/BerriAI/litellm/issues/2777.

parent 71db88115d
commit 0d949d71ab

4 changed files with 95 additions and 14 deletions
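With this change, `text_completion` / `atext_completion` accept the `prompt` as a list of strings, not only a single string. A minimal usage sketch (the model is the one used in the new test below; the prompt text is illustrative):

import litellm

# Each string in the list is converted into its own {"role": "user", ...} message
# before the call is routed through completion().
response = litellm.text_completion(
    model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
    prompt=["say hello", "say goodbye"],
)
print(response)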
@@ -118,6 +118,11 @@ class LangFuseLogger:
         ):
             input = prompt
             output = response_obj["choices"][0]["message"].json()
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.TextCompletionResponse
+        ):
+            input = prompt
+            output = response_obj.choices[0].text
         elif response_obj is not None and isinstance(
             response_obj, litellm.ImageResponse
         ):
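The new branch mirrors the existing chat-completion handling but reads the generated text off the text-completion object instead. A hypothetical standalone helper, mirroring the two branches above (the attribute and key access is taken from the diff; everything else is assumed and not part of this commit):

import litellm

def _extract_logged_output(response_obj):
    # Hypothetical helper for illustration only, not part of the diff.
    if isinstance(response_obj, litellm.TextCompletionResponse):
        # text completion: the generated text lives on choices[0].text
        return response_obj.choices[0].text
    # chat completion: the whole message object is logged
    return response_obj["choices"][0]["message"].json()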
@@ -242,6 +247,7 @@ class LangFuseLogger:

         print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")

+        print(f"response_obj: {response_obj}")
         if supports_tags:
             metadata_tags = metadata.get("tags", [])
             tags = metadata_tags
@@ -306,11 +312,13 @@ class LangFuseLogger:
             usage = None
             if response_obj is not None and response_obj.get("id", None) is not None:
                 generation_id = litellm.utils.get_logging_id(start_time, response_obj)
+                print(f"getting usage, cost={cost}")
                 usage = {
                     "prompt_tokens": response_obj["usage"]["prompt_tokens"],
                     "completion_tokens": response_obj["usage"]["completion_tokens"],
                     "total_cost": cost if supports_costs else None,
                 }
+            print(f"constructed usage - {usage}")
             generation_name = metadata.get("generation_name", None)
             if generation_name is None:
                 # just log `litellm-{call_type}` as the generation name
@@ -3163,7 +3163,18 @@ def text_completion(
     # these are the params supported by Completion() but not ChatCompletion

     # default case, non OpenAI requests go through here
-    messages = [{"role": "system", "content": prompt}]
+    # handle prompt formatting if prompt is a string vs. list of strings
+    messages = []
+    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str):
+        for p in prompt:
+            message = {"role": "user", "content": p}
+            messages.append(message)
+    elif isinstance(prompt, str):
+        messages = [{"role": "user", "content": prompt}]
+    else:
+        raise Exception(
+            f"Unmapped prompt format. Your prompt is neither a list of strings nor a string. prompt={prompt}. File an issue - https://github.com/BerriAI/litellm/issues"
+        )
     kwargs.pop("prompt", None)
     response = completion(
         model=model,
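To make the new mapping concrete, here is the same normalization written as a hypothetical standalone function with example inputs (for illustration only; `prompt_to_messages` does not exist in litellm):

def prompt_to_messages(prompt):
    # Hypothetical standalone version of the block above, for illustration only.
    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str):
        return [{"role": "user", "content": p} for p in prompt]
    elif isinstance(prompt, str):
        return [{"role": "user", "content": prompt}]
    raise ValueError(f"Unmapped prompt format: {prompt}")

assert prompt_to_messages("hi") == [{"role": "user", "content": "hi"}]
assert prompt_to_messages(["a", "b"]) == [
    {"role": "user", "content": "a"},
    {"role": "user", "content": "b"},
]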
@@ -686,6 +686,44 @@ async def test_async_chat_vertex_ai_stream():
 # Text Completion


+@pytest.mark.asyncio
+async def test_async_text_completion_bedrock():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.atext_completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            prompt=["Hi 👋 - i'm async text completion bedrock"],
+        )
+        # test streaming
+        response = await litellm.atext_completion(
+            model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
+            prompt=["Hi 👋 - i'm async text completion bedrock"],
+            stream=True,
+        )
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.atext_completion(
+                model="bedrock/",
+                prompt=["Hi 👋 - i'm async text completion bedrock"],
+                stream=True,
+                api_key="my-bad-key",
+            )
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
@@ -852,9 +852,16 @@ class TranscriptionResponse(OpenAIObject):


 ############################################################
-def print_verbose(print_statement, logger_only: bool = False):
+def print_verbose(
+    print_statement,
+    logger_only: bool = False,
+    log_level: Literal["DEBUG", "INFO"] = "DEBUG",
+):
     try:
-        verbose_logger.debug(print_statement)
+        if log_level == "DEBUG":
+            verbose_logger.debug(print_statement)
+        elif log_level == "INFO":
+            verbose_logger.info(print_statement)
         if litellm.set_verbose == True and logger_only == False:
             print(print_statement)  # noqa
     except:
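With the extra parameter, call sites can opt into INFO-level logging while the default stays at DEBUG. Hedged examples (the first message is made up; the second matches the call sites changed later in this commit):

# Default: routed to verbose_logger.debug(...)
print_verbose("checking cache for request")  # illustrative message

# Explicit INFO: routed to verbose_logger.info(...)
print_verbose("Request to litellm:", log_level="INFO")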
@@ -903,10 +910,20 @@ class Logging:
             raise ValueError(
                 f"Invalid call_type {call_type}. Allowed values: {allowed_values}"
             )
-        if messages is not None and isinstance(messages, str):
-            messages = [
-                {"role": "user", "content": messages}
-            ]  # convert text completion input to the chat completion format
+        if messages is not None:
+            if isinstance(messages, str):
+                messages = [
+                    {"role": "user", "content": messages}
+                ]  # convert text completion input to the chat completion format
+            elif (
+                isinstance(messages, list)
+                and len(messages) > 0
+                and isinstance(messages[0], str)
+            ):
+                new_messages = []
+                for m in messages:
+                    new_messages.append({"role": "user", "content": m})
+                messages = new_messages
         self.model = model
         self.messages = messages
         self.stream = stream
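This is the logging-side counterpart of the main.py change: callbacks that read their input from this Logging object should now see chat-format messages whether the caller passed a single string or a list of strings. A small illustration of the expected result (assumed values, mirroring the branch above):

# What self.messages should look like after the branch above runs,
# assuming the caller passed prompt=["Hi 👋", "What's the weather?"]:
expected_messages = [
    {"role": "user", "content": "Hi 👋"},
    {"role": "user", "content": "What's the weather?"},
]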
@@ -1199,6 +1216,7 @@ class Logging:
                 or isinstance(result, EmbeddingResponse)
                 or isinstance(result, ImageResponse)
                 or isinstance(result, TranscriptionResponse)
+                or isinstance(result, TextCompletionResponse)
             )
             and self.stream != True
         ):  # handle streaming separately
@@ -4464,7 +4482,7 @@ def get_optional_params(
     if unsupported_params and not litellm.drop_params:
         raise UnsupportedParamsError(
             status_code=500,
-            message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.",
+            message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
         )

 def _map_and_modify_arg(supported_params: dict, provider: str, model: str):
@@ -10374,22 +10392,28 @@ def print_args_passed_to_litellm(original_function, args, kwargs):

         args_str = ", ".join(map(repr, args))
         kwargs_str = ", ".join(f"{key}={repr(value)}" for key, value in kwargs.items())
-        print_verbose("\n")  # new line before
-        print_verbose("\033[92mRequest to litellm:\033[0m")
+        print_verbose("\n", log_level="INFO")  # new line before
+        print_verbose("\033[92mRequest to litellm:\033[0m", log_level="INFO")
         if args and kwargs:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({args_str}, {kwargs_str})\033[0m",
+                log_level="INFO",
             )
         elif args:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({args_str})\033[0m",
+                log_level="INFO",
             )
         elif kwargs:
             print_verbose(
-                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m"
+                f"\033[92mlitellm.{original_function.__name__}({kwargs_str})\033[0m",
+                log_level="INFO",
             )
         else:
-            print_verbose(f"\033[92mlitellm.{original_function.__name__}()\033[0m")
+            print_verbose(
+                f"\033[92mlitellm.{original_function.__name__}()\033[0m",
+                log_level="INFO",
+            )
         print_verbose("\n")  # new line after
     except:
         # This should always be non blocking