LiteLLM Minor Fixes & Improvements (12/16/2024) - p1 (#7263)

* fix(factory.py): skip empty text blocks for bedrock user messages

Fixes https://github.com/BerriAI/litellm/issues/7169

* Add support for Gemini 2.0 GoogleSearch tool (#7257)

* Add support for google_search tool in gemini 2.0

* Add/modify tests

* Fix grounding check

* Remove 2.0 grounding test; exclude experimental model in VERTEX_MODELS_TO_NOT_TEST

* Swap order of tools

* Fix formatting
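
As a rough usage sketch of the new tool (not taken from this PR's docs): the `googleSearch` tool key and the `gemini/gemini-2.0-flash-exp` model name below are assumptions about how the feature is exercised.

```python
# Hedged sketch: invoking the Gemini 2.0 Google Search tool through litellm.
# The tool dict shape and the experimental model name are assumptions.
import litellm

response = litellm.completion(
    model="gemini/gemini-2.0-flash-exp",  # assumed experimental 2.0 model
    messages=[{"role": "user", "content": "What changed in the latest Gemini release?"}],
    tools=[{"googleSearch": {}}],  # assumed shape of the new google_search tool
)
print(response.choices[0].message.content)
```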

* fix(get_api_base.py): return api base in streaming response

Fixes https://github.com/BerriAI/litellm/issues/7249

Closes https://github.com/BerriAI/litellm/pull/7250
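
For context, `get_api_base` (now imported from `llm_response_utils/get_api_base.py`, see the diff below) can also be called directly; a small sketch based on the signature shown in the diff:

```python
# Sketch: resolving the api base for a model, per the signature in the diff
# (get_api_base(model, optional_params) -> Optional[str]).
from litellm import get_api_base

api_base = get_api_base(model="gemini/gemini-pro", optional_params={})
# Per the implementation shown below, a Gemini model resolves to the
# generativelanguage.googleapis.com ...:generateContent endpoint.
print(api_base)
```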

* fix(cost_calculator.py): only set base model to model if not None

Fixes https://github.com/BerriAI/litellm/issues/7223

* fix(cost_calculator.py): enforce stricter order when picking model for cost calculation

* fix(cost_calculator.py): fix '_select_model_name_for_cost_calc' to return model name with region name prefix if provided

* fix(utils.py): fix 'get_model_info()' to handle edge case where model name starts with custom llm provider AND custom llm provider is given
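
The edge case in question is a provider-prefixed model name combined with an explicit `custom_llm_provider`; a sketch with an illustrative model name:

```python
# Sketch of the edge case: the model name already carries the provider prefix
# AND custom_llm_provider is passed. Without the fix, the combined lookup key
# could become a doubled "gemini/gemini/..." name and miss the pricing map.
import litellm

info = litellm.get_model_info(
    model="gemini/gemini-1.5-pro",  # illustrative model name
    custom_llm_provider="gemini",
)
print(info["input_cost_per_token"], info["output_cost_per_token"])
```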

* fix(cost_calculator.py): handle `custom_llm_provider-` scenario

* fix(cost_calculator.py): e2e working tts cost tracking

Ensures the initial message is passed to the cost calculator.
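
A hedged sketch of what end-to-end TTS cost tracking looks like from the caller's side; the callback signature is standard litellm usage, while the `response_cost` key in the callback kwargs is an assumption about where the tracked cost surfaces:

```python
# Sketch: observing the tracked cost of a text-to-speech call via a success callback.
# `response_cost` in kwargs is an assumption about litellm's logging payload.
import litellm

def log_cost(kwargs, completion_response, start_time, end_time):
    print("tts cost:", kwargs.get("response_cost"))

litellm.success_callback = [log_cost]

litellm.speech(
    model="openai/tts-1",
    voice="alloy",
    input="Hello from the cost calculator",  # the initial message now reaches the cost calculator
)
```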

* fix(factory.py): suppress linting errors

* fix(cost_calculator.py): strip llm provider from model name after selecting cost calc model

* fix(litellm_logging.py): store initial request in 'input' field + accept base_model to be passed in litellm_params directly
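
A sketch of passing `base_model` inside `litellm_params` (shown here via the Router) so cost calculation can map a custom deployment name onto a known pricing entry; the alias, deployment name, and endpoint below are placeholders:

```python
# Sketch: base_model supplied directly in litellm_params for cost calculation.
# Deployment name, alias, and endpoint are placeholders.
import os

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "prod-gpt",  # alias used by callers
            "litellm_params": {
                "model": "azure/my-deployment",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_base": "https://example-resource.openai.azure.com",
                "base_model": "azure/gpt-4o",  # read directly for cost calc per this change
            },
        }
    ]
)
```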

* test: handle none env var value in flaky test

* fix(litellm_logging.py): fix linting errors

---------

Co-authored-by: Sam B <samlingx@gmail.com>
Krish Dholakia authored on 2024-12-17 15:33:36 -08:00, committed by GitHub
parent 7f18a82f72
commit f966e279a6
20 changed files with 581 additions and 261 deletions

@@ -85,6 +85,10 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
    convert_to_streaming_response,
    convert_to_streaming_response_async,
)
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
    get_formatted_prompt,
)
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
    get_response_headers,
)
@@ -3904,121 +3908,6 @@ def get_model_region(
    return None

def get_api_base(
    model: str, optional_params: Union[dict, LiteLLM_Params]
) -> Optional[str]:
    """
    Returns the api base used for calling the model.

    Parameters:
    - model: str - the model passed to litellm.completion()
    - optional_params - the 'litellm_params' in router.completion *OR* additional params passed to litellm.completion - eg. api_base, api_key, etc. See `LiteLLM_Params` - https://github.com/BerriAI/litellm/blob/f09e6ba98d65e035a79f73bc069145002ceafd36/litellm/router.py#L67

    Returns:
    - string (api_base) or None

    Example:
    ```
    from litellm import get_api_base

    get_api_base(model="gemini/gemini-pro")
    ```
    """
    try:
        if isinstance(optional_params, LiteLLM_Params):
            _optional_params = optional_params
        elif "model" in optional_params:
            _optional_params = LiteLLM_Params(**optional_params)
        else:  # prevent needing to copy and pop the dict
            _optional_params = LiteLLM_Params(
                model=model, **optional_params
            )  # convert to pydantic object
    except Exception as e:
        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
        return None
    # get llm provider
    if _optional_params.api_base is not None:
        return _optional_params.api_base

    if litellm.model_alias_map and model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]
    try:
        (
            model,
            custom_llm_provider,
            dynamic_api_key,
            dynamic_api_base,
        ) = get_llm_provider(
            model=model,
            custom_llm_provider=_optional_params.custom_llm_provider,
            api_base=_optional_params.api_base,
            api_key=_optional_params.api_key,
        )
    except Exception as e:
        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
        custom_llm_provider = None
        dynamic_api_base = None

    if dynamic_api_base is not None:
        return dynamic_api_base

    stream: bool = getattr(optional_params, "stream", False)

    if (
        _optional_params.vertex_location is not None
        and _optional_params.vertex_project is not None
    ):
        from litellm.llms.vertex_ai.vertex_ai_partner_models.main import (
            VertexPartnerProvider,
            create_vertex_url,
        )

        if "claude" in model:
            _api_base = create_vertex_url(
                vertex_location=_optional_params.vertex_location,
                vertex_project=_optional_params.vertex_project,
                model=model,
                stream=stream,
                partner=VertexPartnerProvider.claude,
            )
        else:
            if stream:
                _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
                    _optional_params.vertex_location,
                    _optional_params.vertex_project,
                    _optional_params.vertex_location,
                    model,
                )
            else:
                _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent".format(
                    _optional_params.vertex_location,
                    _optional_params.vertex_project,
                    _optional_params.vertex_location,
                    model,
                )
        return _api_base

    if custom_llm_provider is None:
        return None

    if custom_llm_provider == "gemini":
        if stream:
            _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent".format(
                model
            )
        else:
            _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
                model
            )
        return _api_base
    elif custom_llm_provider == "openai":
        _api_base = "https://api.openai.com"
        return _api_base
    return None

def get_first_chars_messages(kwargs: dict) -> str:
    try:
        _messages = kwargs.get("messages")
@@ -4034,54 +3923,6 @@ def _count_characters(text: str) -> int:
    return len(filtered_text)

def get_formatted_prompt(
    data: dict,
    call_type: Literal[
        "completion",
        "embedding",
        "image_generation",
        "audio_transcription",
        "moderation",
        "text_completion",
    ],
) -> str:
    """
    Extracts the prompt from the input data based on the call type.
    Returns a string.
    """
    prompt = ""
    if call_type == "completion":
        for message in data["messages"]:
            if message.get("content", None) is not None:
                content = message.get("content")
                if isinstance(content, str):
                    prompt += message["content"]
                elif isinstance(content, List):
                    for c in content:
                        if c["type"] == "text":
                            prompt += c["text"]
            if "tool_calls" in message:
                for tool_call in message["tool_calls"]:
                    if "function" in tool_call:
                        function_arguments = tool_call["function"]["arguments"]
                        prompt += function_arguments
    elif call_type == "text_completion":
        prompt = data["prompt"]
    elif call_type == "embedding" or call_type == "moderation":
        if isinstance(data["input"], str):
            prompt = data["input"]
        elif isinstance(data["input"], list):
            for m in data["input"]:
                prompt += m
    elif call_type == "image_generation":
        prompt = data["prompt"]
    elif call_type == "audio_transcription":
        if "prompt" in data:
            prompt = data["prompt"]
    return prompt

def get_response_string(response_obj: ModelResponse) -> str:
    _choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
@@ -4283,6 +4124,62 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
    return True

from typing import TypedDict

class PotentialModelNamesAndCustomLLMProvider(TypedDict):
    split_model: str
    combined_model_name: str
    stripped_model_name: str
    combined_stripped_model_name: str
    custom_llm_provider: str

def _get_potential_model_names(
    model: str, custom_llm_provider: Optional[str]
) -> PotentialModelNamesAndCustomLLMProvider:
    if custom_llm_provider is None:
        # Get custom_llm_provider
        try:
            split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        except Exception:
            split_model = model
        combined_model_name = model
        stripped_model_name = _strip_model_name(
            model=model, custom_llm_provider=custom_llm_provider
        )
        combined_stripped_model_name = stripped_model_name
    elif custom_llm_provider and model.startswith(
        custom_llm_provider + "/"
    ):  # handle case where custom_llm_provider is provided and model starts with custom_llm_provider
        split_model = model.split("/")[1]
        combined_model_name = model
        stripped_model_name = _strip_model_name(
            model=split_model, custom_llm_provider=custom_llm_provider
        )
        combined_stripped_model_name = "{}/{}".format(
            custom_llm_provider, stripped_model_name
        )
    else:
        split_model = model
        combined_model_name = "{}/{}".format(custom_llm_provider, model)
        stripped_model_name = _strip_model_name(
            model=model, custom_llm_provider=custom_llm_provider
        )
        combined_stripped_model_name = "{}/{}".format(
            custom_llm_provider,
            _strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
        )

    return PotentialModelNamesAndCustomLLMProvider(
        split_model=split_model,
        combined_model_name=combined_model_name,
        stripped_model_name=stripped_model_name,
        combined_stripped_model_name=combined_stripped_model_name,
        custom_llm_provider=cast(str, custom_llm_provider),
    )

def get_model_info(  # noqa: PLR0915
    model: str, custom_llm_provider: Optional[str] = None
) -> ModelInfo:
@@ -4390,28 +4287,16 @@ def get_model_info(  # noqa: PLR0915
        elif model + "@latest" in litellm.vertex_ai_ai21_models:
            model = model + "@latest"
        ##########################
        if custom_llm_provider is None:
            # Get custom_llm_provider
            try:
                split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
            except Exception:
                split_model = model
            combined_model_name = model
            stripped_model_name = _strip_model_name(
                model=model, custom_llm_provider=custom_llm_provider
            )
            combined_stripped_model_name = stripped_model_name
        else:
            split_model = model
            combined_model_name = "{}/{}".format(custom_llm_provider, model)
            stripped_model_name = _strip_model_name(
                model=model, custom_llm_provider=custom_llm_provider
            )
            combined_stripped_model_name = "{}/{}".format(
                custom_llm_provider,
                _strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
            )
        potential_model_names = _get_potential_model_names(
            model=model, custom_llm_provider=custom_llm_provider
        )
        combined_model_name = potential_model_names["combined_model_name"]
        stripped_model_name = potential_model_names["stripped_model_name"]
        combined_stripped_model_name = potential_model_names[
            "combined_stripped_model_name"
        ]
        split_model = potential_model_names["split_model"]
        custom_llm_provider = potential_model_names["custom_llm_provider"]
        #########################
        supported_openai_params = litellm.get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
@@ -5956,8 +5841,10 @@ def _get_base_model_from_metadata(model_call_details=None):
    if model_call_details is None:
        return None
    litellm_params = model_call_details.get("litellm_params", {})
    if litellm_params is not None:
        _base_model = litellm_params.get("base_model", None)
        if _base_model is not None:
            return _base_model
        metadata = litellm_params.get("metadata", {})
        return _get_base_model_from_litellm_call_metadata(metadata=metadata)