Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
LiteLLM Minor Fixes & Improvements (12/16/2024) - p1 (#7263)
* fix(factory.py): skip empty text blocks for bedrock user messages. Fixes https://github.com/BerriAI/litellm/issues/7169
* Add support for Gemini 2.0 GoogleSearch tool (#7257)
  * Add support for google_search tool in gemini 2.0
  * Add/modify tests
  * Fix grounding check
  * Remove 2.0 grounding test; exclude experimental model in VERTEX_MODELS_TO_NOT_TEST
  * Swap order of tools
  * Fix formatting
* fix(get_api_base.py): return api base in streaming response. Fixes https://github.com/BerriAI/litellm/issues/7249, closes https://github.com/BerriAI/litellm/pull/7250
* fix(cost_calculator.py): only set base model to model if not none. Fixes https://github.com/BerriAI/litellm/issues/7223
* fix(cost_calculator.py): enforce stricter order when picking model for cost calculation
* fix(cost_calculator.py): fix '_select_model_name_for_cost_calc' to return model name with region name prefix if provided
* fix(utils.py): fix 'get_model_info()' to handle edge case where model name starts with custom llm provider AND custom llm provider is given
* fix(cost_calculator.py): handle `custom_llm_provider-` scenario
* fix(cost_calculator.py): e2e working tts cost tracking; ensures the initial message is passed in to the cost calculator
* fix(factory.py): suppress linting errors
* fix(cost_calculator.py): strip llm provider from model name after selecting cost calc model
* fix(litellm_logging.py): store initial request in 'input' field + accept base_model to be passed in litellm_params directly
* test: handle none env var value in flaky test
* fix(litellm_logging.py): fix linting errors

---------

Co-authored-by: Sam B <samlingx@gmail.com>
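One of the headline items above is Gemini 2.0 GoogleSearch tool support (#7257). A rough sketch of how that feature is meant to be exercised through litellm.completion; the tool dict shape and the gemini-2.0 model alias are assumptions based on the PR description, not taken from this diff:

import os
from litellm import completion

os.environ["GEMINI_API_KEY"] = "your-api-key"  # placeholder

# Gemini 2.0 uses the "googleSearch" tool name; Gemini 1.5 used "googleSearchRetrieval".
tools = [{"googleSearch": {}}]

response = completion(
    model="gemini/gemini-2.0-flash-exp",  # assumed model alias
    messages=[{"role": "user", "content": "Who won the most recent Formula 1 championship?"}],
    tools=tools,
)
print(response.choices[0].message.content)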
This commit is contained in:
parent: f42b31b743
commit: 179d2f56b7
20 changed files with 581 additions and 261 deletions
litellm/utils.py (259 changed lines shown below)
@@ -85,6 +85,10 @@ from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
    convert_to_streaming_response,
    convert_to_streaming_response_async,
)
from litellm.litellm_core_utils.llm_response_utils.get_api_base import get_api_base
from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
    get_formatted_prompt,
)
from litellm.litellm_core_utils.llm_response_utils.get_headers import (
    get_response_headers,
)
@@ -3904,121 +3908,6 @@ def get_model_region(
    return None


def get_api_base(
    model: str, optional_params: Union[dict, LiteLLM_Params]
) -> Optional[str]:
    """
    Returns the api base used for calling the model.

    Parameters:
    - model: str - the model passed to litellm.completion()
    - optional_params - the 'litellm_params' in router.completion *OR* additional params passed to litellm.completion - eg. api_base, api_key, etc. See `LiteLLM_Params` - https://github.com/BerriAI/litellm/blob/f09e6ba98d65e035a79f73bc069145002ceafd36/litellm/router.py#L67

    Returns:
    - string (api_base) or None

    Example:
    ```
    from litellm import get_api_base

    get_api_base(model="gemini/gemini-pro")
    ```
    """

    try:
        if isinstance(optional_params, LiteLLM_Params):
            _optional_params = optional_params
        elif "model" in optional_params:
            _optional_params = LiteLLM_Params(**optional_params)
        else:  # prevent needing to copy and pop the dict
            _optional_params = LiteLLM_Params(
                model=model, **optional_params
            )  # convert to pydantic object
    except Exception as e:
        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
        return None
    # get llm provider

    if _optional_params.api_base is not None:
        return _optional_params.api_base

    if litellm.model_alias_map and model in litellm.model_alias_map:
        model = litellm.model_alias_map[model]
    try:
        (
            model,
            custom_llm_provider,
            dynamic_api_key,
            dynamic_api_base,
        ) = get_llm_provider(
            model=model,
            custom_llm_provider=_optional_params.custom_llm_provider,
            api_base=_optional_params.api_base,
            api_key=_optional_params.api_key,
        )
    except Exception as e:
        verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
        custom_llm_provider = None
        dynamic_api_base = None

    if dynamic_api_base is not None:
        return dynamic_api_base

    stream: bool = getattr(optional_params, "stream", False)

    if (
        _optional_params.vertex_location is not None
        and _optional_params.vertex_project is not None
    ):
        from litellm.llms.vertex_ai.vertex_ai_partner_models.main import (
            VertexPartnerProvider,
            create_vertex_url,
        )

        if "claude" in model:
            _api_base = create_vertex_url(
                vertex_location=_optional_params.vertex_location,
                vertex_project=_optional_params.vertex_project,
                model=model,
                stream=stream,
                partner=VertexPartnerProvider.claude,
            )
        else:
            if stream:
                _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent".format(
                    _optional_params.vertex_location,
                    _optional_params.vertex_project,
                    _optional_params.vertex_location,
                    model,
                )
            else:
                _api_base = "{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/{}:generateContent".format(
                    _optional_params.vertex_location,
                    _optional_params.vertex_project,
                    _optional_params.vertex_location,
                    model,
                )
        return _api_base

    if custom_llm_provider is None:
        return None

    if custom_llm_provider == "gemini":
        if stream:
            _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent".format(
                model
            )
        else:
            _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format(
                model
            )
        return _api_base
    elif custom_llm_provider == "openai":
        _api_base = "https://api.openai.com"
        return _api_base
    return None


def get_first_chars_messages(kwargs: dict) -> str:
    try:
        _messages = kwargs.get("messages")
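The get_api_base implementation above is deleted from utils.py; per the import hunk at the top of the file it now lives in litellm/litellm_core_utils/llm_response_utils/get_api_base.py, and the commit message notes the relocated version also returns the api base for streaming responses. A minimal sketch of calling it, building on the docstring's own example; the explicit api_base value below is an illustrative assumption:

from litellm import get_api_base

# No overrides passed: the provider is inferred from the "gemini/" model prefix.
print(get_api_base(model="gemini/gemini-pro", optional_params={}))

# An explicit api_base in optional_params short-circuits everything else and is returned as-is.
print(
    get_api_base(
        model="openai/gpt-4o",
        optional_params={"api_base": "https://my-proxy.example.com/v1"},
    )
)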
@@ -4034,54 +3923,6 @@ def _count_characters(text: str) -> int:
    return len(filtered_text)


def get_formatted_prompt(
    data: dict,
    call_type: Literal[
        "completion",
        "embedding",
        "image_generation",
        "audio_transcription",
        "moderation",
        "text_completion",
    ],
) -> str:
    """
    Extracts the prompt from the input data based on the call type.

    Returns a string.
    """
    prompt = ""
    if call_type == "completion":
        for message in data["messages"]:
            if message.get("content", None) is not None:
                content = message.get("content")
                if isinstance(content, str):
                    prompt += message["content"]
                elif isinstance(content, List):
                    for c in content:
                        if c["type"] == "text":
                            prompt += c["text"]
            if "tool_calls" in message:
                for tool_call in message["tool_calls"]:
                    if "function" in tool_call:
                        function_arguments = tool_call["function"]["arguments"]
                        prompt += function_arguments
    elif call_type == "text_completion":
        prompt = data["prompt"]
    elif call_type == "embedding" or call_type == "moderation":
        if isinstance(data["input"], str):
            prompt = data["input"]
        elif isinstance(data["input"], list):
            for m in data["input"]:
                prompt += m
    elif call_type == "image_generation":
        prompt = data["prompt"]
    elif call_type == "audio_transcription":
        if "prompt" in data:
            prompt = data["prompt"]
    return prompt


def get_response_string(response_obj: ModelResponse) -> str:
    _choices: List[Union[Choices, StreamingChoices]] = response_obj.choices
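get_formatted_prompt is likewise removed here in favour of the copy imported from llm_response_utils. A small sketch of the completion branch shown above, with illustrative message content:

from litellm.litellm_core_utils.llm_response_utils.get_formatted_prompt import (
    get_formatted_prompt,
)

# String content and list-of-parts content are both concatenated into one prompt string.
data = {
    "messages": [
        {"role": "user", "content": "Summarize this repository."},
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "It is an LLM gateway."}],
        },
    ]
}
print(get_formatted_prompt(data=data, call_type="completion"))
# -> "Summarize this repository.It is an LLM gateway."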
@@ -4283,6 +4124,62 @@ def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str])
    return True


from typing import TypedDict


class PotentialModelNamesAndCustomLLMProvider(TypedDict):
    split_model: str
    combined_model_name: str
    stripped_model_name: str
    combined_stripped_model_name: str
    custom_llm_provider: str


def _get_potential_model_names(
    model: str, custom_llm_provider: Optional[str]
) -> PotentialModelNamesAndCustomLLMProvider:
    if custom_llm_provider is None:
        # Get custom_llm_provider
        try:
            split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
        except Exception:
            split_model = model
        combined_model_name = model
        stripped_model_name = _strip_model_name(
            model=model, custom_llm_provider=custom_llm_provider
        )
        combined_stripped_model_name = stripped_model_name
    elif custom_llm_provider and model.startswith(
        custom_llm_provider + "/"
    ):  # handle case where custom_llm_provider is provided and model starts with custom_llm_provider
        split_model = model.split("/")[1]
        combined_model_name = model
        stripped_model_name = _strip_model_name(
            model=split_model, custom_llm_provider=custom_llm_provider
        )
        combined_stripped_model_name = "{}/{}".format(
            custom_llm_provider, stripped_model_name
        )
    else:
        split_model = model
        combined_model_name = "{}/{}".format(custom_llm_provider, model)
        stripped_model_name = _strip_model_name(
            model=model, custom_llm_provider=custom_llm_provider
        )
        combined_stripped_model_name = "{}/{}".format(
            custom_llm_provider,
            _strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
        )

    return PotentialModelNamesAndCustomLLMProvider(
        split_model=split_model,
        combined_model_name=combined_model_name,
        stripped_model_name=stripped_model_name,
        combined_stripped_model_name=combined_stripped_model_name,
        custom_llm_provider=cast(str, custom_llm_provider),
    )


def get_model_info(  # noqa: PLR0915
    model: str, custom_llm_provider: Optional[str] = None
) -> ModelInfo:
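The new helper centralizes the name-variant bookkeeping that get_model_info used to inline (see the next hunk), and its middle elif branch is the edge-case fix called out in the commit message: a model name that already starts with its custom_llm_provider. A hedged sketch of what that branch produces; the model name is illustrative and the stripped variants depend on what _strip_model_name removes:

from litellm.utils import _get_potential_model_names

# Hypothetical call into the private helper added above.
names = _get_potential_model_names(
    model="vertex_ai/gemini-1.5-pro", custom_llm_provider="vertex_ai"
)
# The elif branch fires because the name starts with "vertex_ai/":
#   names["split_model"]         == "gemini-1.5-pro"
#   names["combined_model_name"] == "vertex_ai/gemini-1.5-pro"
#   names["custom_llm_provider"] == "vertex_ai"
# stripped_model_name / combined_stripped_model_name come from _strip_model_name.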
@@ -4390,28 +4287,16 @@ def get_model_info(  # noqa: PLR0915
            elif model + "@latest" in litellm.vertex_ai_ai21_models:
                model = model + "@latest"
        ##########################
        if custom_llm_provider is None:
            # Get custom_llm_provider
            try:
                split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
            except Exception:
                split_model = model
            combined_model_name = model
            stripped_model_name = _strip_model_name(
                model=model, custom_llm_provider=custom_llm_provider
            )
            combined_stripped_model_name = stripped_model_name
        else:
            split_model = model
            combined_model_name = "{}/{}".format(custom_llm_provider, model)
            stripped_model_name = _strip_model_name(
                model=model, custom_llm_provider=custom_llm_provider
            )
            combined_stripped_model_name = "{}/{}".format(
                custom_llm_provider,
                _strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
            )

        potential_model_names = _get_potential_model_names(
            model=model, custom_llm_provider=custom_llm_provider
        )
        combined_model_name = potential_model_names["combined_model_name"]
        stripped_model_name = potential_model_names["stripped_model_name"]
        combined_stripped_model_name = potential_model_names[
            "combined_stripped_model_name"
        ]
        split_model = potential_model_names["split_model"]
        custom_llm_provider = potential_model_names["custom_llm_provider"]
        #########################
        supported_openai_params = litellm.get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
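With the helper wired in, get_model_info should resolve the same cost-map entry whether the caller passes a bare model name or a provider-prefixed name together with an explicit custom_llm_provider. A sketch of the two call styles the fix is meant to make equivalent; the model name is illustrative and the field access assumes the usual ModelInfo keys:

import litellm

# Previously the prefixed-name + explicit-provider combination could fail to match.
info_a = litellm.get_model_info(model="gemini-1.5-pro", custom_llm_provider="gemini")
info_b = litellm.get_model_info(model="gemini/gemini-1.5-pro", custom_llm_provider="gemini")
assert info_a["input_cost_per_token"] == info_b["input_cost_per_token"]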
@@ -5956,8 +5841,10 @@ def _get_base_model_from_metadata(model_call_details=None):
    if model_call_details is None:
        return None
    litellm_params = model_call_details.get("litellm_params", {})

    if litellm_params is not None:
        _base_model = litellm_params.get("base_model", None)
        if _base_model is not None:
            return _base_model
        metadata = litellm_params.get("metadata", {})

        return _get_base_model_from_litellm_call_metadata(metadata=metadata)
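The final hunk lets cost tracking honour a base_model set directly in litellm_params rather than only one tucked inside metadata. A rough sketch of a Router deployment that relies on this; the names and endpoint are placeholders, and previously base_model was typically supplied under model_info instead:

import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "azure-gpt-4o",
            "litellm_params": {
                "model": "azure/my-gpt4o-deployment",  # deployment name, placeholder
                "api_base": "https://example.openai.azure.com",
                "api_key": os.environ.get("AZURE_API_KEY", ""),
                # With this change, cost calculation can read base_model from here directly.
                "base_model": "azure/gpt-4o",
            },
        }
    ]
)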