fix(key_management_endpoints.py): override metadata field value on up… (#7008)

* fix(key_management_endpoints.py): override metadata field value on update

allow user to override tags
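As an illustration of the fix, a key's metadata (and therefore its tags) can now be overridden through the existing `/key/update` endpoint; the proxy URL, key values, and tags below are placeholders, not taken from this change:

```python
import requests

# Hypothetical example: replace a virtual key's metadata (including tags) via the proxy.
resp = requests.post(
    "http://localhost:4000/key/update",
    headers={"Authorization": "Bearer sk-1234"},  # proxy admin/master key (placeholder)
    json={
        "key": "sk-existing-virtual-key",
        "metadata": {"tags": ["prod", "team-a"]},  # this value now overrides the stored metadata
    },
)
print(resp.json())
```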

* feat(__init__.py): expose new `disable_end_user_cost_tracking_prometheus_only` setting

allow disabling end-user cost tracking on Prometheus only - fixes the metric cardinality issue
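Since the flag is exported from `__init__.py`, a minimal sketch of enabling it in the SDK (on the proxy it would be set in the config instead; the behavior it controls is shown in the `get_end_user_id_for_cost_tracking` diff below):

```python
import litellm

# Keep end-user cost tracking for logging/DB, but drop the end-user dimension from
# Prometheus metrics to avoid a label-cardinality blow-up.
litellm.disable_end_user_cost_tracking_prometheus_only = True

# The pre-existing flag still disables end-user cost tracking everywhere:
# litellm.disable_end_user_cost_tracking = True
```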

* fix(litellm_pre_call_utils.py): add key/team level enforced params

Fixes https://github.com/BerriAI/litellm/issues/6652

* fix(key_management_endpoints.py): allow user to pass in `enforced_params` as a top level param on /key/generate and /key/update
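For example, a key can now be minted with enforced params in a single call; the field values below are illustrative, while the endpoint and parameter name come from this change:

```python
import requests

# Hypothetical example: every LLM request made with the generated key must include
# a top-level `user` field and a `metadata.generation_name` field, or the proxy rejects it.
resp = requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},  # proxy admin/master key (placeholder)
    json={"enforced_params": ["user", "metadata.generation_name"]},
)
print(resp.json())
```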

* docs(enterprise.md): add docs on enforcing required params for llm requests

* Add support for the Galadriel API (#7005)
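Galadriel is wired in as an OpenAI-compatible provider behind a `galadriel/` model prefix; the model id and environment variable below are assumptions for illustration, not taken from this diff:

```python
import os
import litellm

os.environ["GALADRIEL_API_KEY"] = "sk-..."  # assumed env var name

response = litellm.completion(
    model="galadriel/llama3.1",  # illustrative model id
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```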

* fix(router.py): robust retry-after handling

set retry-after time to 0 if there are >0 healthy deployments; handle the base case of 1 deployment
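The rule described above, as a minimal sketch (an illustrative helper, not the actual router.py implementation):

```python
# If at least one deployment is still healthy, retry immediately; only honor the
# cooldown/backoff when every deployment for the model group is unavailable.
def retry_after_seconds(num_healthy_deployments: int, cooldown_seconds: float) -> float:
    if num_healthy_deployments > 0:
        return 0.0
    return cooldown_seconds
```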

* test(test_router.py): fix test

* feat(bedrock/): add support for 'nova' models

also adds an explicit 'converse/' route for simpler routing
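With this, Nova models go through the Converse API by default, and a `bedrock/converse/<model>` prefix can be used to force the converse route explicitly; the model id below is illustrative:

```python
import litellm

# Nova models are routed via the Bedrock Converse API (model id illustrative).
resp = litellm.completion(
    model="bedrock/amazon.nova-micro-v1:0",
    messages=[{"role": "user", "content": "Hi"}],
)

# The explicit 'converse/' prefix forces the converse route for a Bedrock model.
resp = litellm.completion(
    model="bedrock/converse/amazon.nova-micro-v1:0",
    messages=[{"role": "user", "content": "Hi"}],
)
```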

* fix: fix 'supports_pdf_input'

return whether the model supports PDF input in `get_model_info`
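The flag is surfaced on the `ModelInfo` dict returned by `get_model_info` (see the diff below); the model id here is illustrative:

```python
import litellm

info = litellm.get_model_info(model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
print(info.get("supports_pdf_input"))  # True/False/None depending on the model map entry
```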

* feat(converse_transformation.py): support bedrock pdf input
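A sketch of the new document-understanding path, assuming the PDF is passed as a base64 `image_url` content part as in the docs added below; the file URL and model id are placeholders:

```python
import base64
import requests
import litellm

# Placeholder URL; any reachable PDF works.
pdf_bytes = requests.get("https://example.com/sample.pdf").content
encoded_pdf = base64.b64encode(pdf_bytes).decode("utf-8")

resp = litellm.completion(
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Summarize this document."},
                # PDF as a base64 data URL; the converse transformation maps it to a
                # Bedrock document block.
                {"type": "image_url", "image_url": f"data:application/pdf;base64,{encoded_pdf}"},
            ],
        }
    ],
)
print(resp.choices[0].message.content)
```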

* docs(document_understanding.md): add document understanding to docs

* fix(litellm_pre_call_utils.py): fix linting error

* fix(init.py): fix passing of bedrock converse models

* feat(bedrock/converse): support 'response_format={"type": "json_object"}'
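For example (model id illustrative):

```python
import litellm

# response_format={"type": "json_object"} is now mapped for Bedrock Converse models.
resp = litellm.completion(
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    messages=[{"role": "user", "content": "Return a JSON object with a 'city' key."}],
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)
```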

* fix(converse_handler.py): fix linting error

* fix(base_llm_unit_tests.py): fix test

* fix: fix test

* test: fix test

* test: fix test

* test: remove duplicate test

---------

Co-authored-by: h4n0 <4738254+h4n0@users.noreply.github.com>
Krish Dholakia 2024-12-03 23:03:50 -08:00 committed by GitHub
parent d558b643be
commit 6bb934c0ac
37 changed files with 1297 additions and 503 deletions


@@ -3142,7 +3142,7 @@ def get_optional_params( # noqa: PLR0915
model=model, custom_llm_provider=custom_llm_provider
)
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
if base_model in litellm.BEDROCK_CONVERSE_MODELS:
if base_model in litellm.bedrock_converse_models:
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.AmazonConverseConfig().map_openai_params(
model=model,
@@ -4308,6 +4308,10 @@ def _strip_stable_vertex_version(model_name) -> str:
return re.sub(r"-\d+$", "", model_name)
def _strip_bedrock_region(model_name) -> str:
return litellm.AmazonConverseConfig()._get_base_model(model_name)
def _strip_openai_finetune_model_name(model_name: str) -> str:
"""
Strips the organization, custom suffix, and ID from an OpenAI fine-tuned model name.
@@ -4324,16 +4328,50 @@ def _strip_openai_finetune_model_name(model_name: str) -> str:
return re.sub(r"(:[^:]+){3}$", "", model_name)
def _strip_model_name(model: str) -> str:
strip_version = _strip_stable_vertex_version(model_name=model)
strip_finetune = _strip_openai_finetune_model_name(model_name=strip_version)
return strip_finetune
def _strip_model_name(model: str, custom_llm_provider: Optional[str]) -> str:
if custom_llm_provider and custom_llm_provider == "bedrock":
strip_bedrock_region = _strip_bedrock_region(model_name=model)
return strip_bedrock_region
elif custom_llm_provider and (
custom_llm_provider == "vertex_ai" or custom_llm_provider == "gemini"
):
strip_version = _strip_stable_vertex_version(model_name=model)
return strip_version
else:
strip_finetune = _strip_openai_finetune_model_name(model_name=model)
return strip_finetune
def _get_model_info_from_model_cost(key: str) -> dict:
return litellm.model_cost[key]
def _check_provider_match(model_info: dict, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the model info provider matches the custom provider.
"""
if custom_llm_provider and (
"litellm_provider" in model_info
and model_info["litellm_provider"] != custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and model_info[
"litellm_provider"
].startswith("vertex_ai"):
return True
elif custom_llm_provider == "fireworks_ai" and model_info[
"litellm_provider"
].startswith("fireworks_ai"):
return True
elif custom_llm_provider == "bedrock" and model_info[
"litellm_provider"
].startswith("bedrock"):
return True
else:
return False
return True
def get_model_info( # noqa: PLR0915
model: str, custom_llm_provider: Optional[str] = None
) -> ModelInfo:
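Roughly, the two new helpers above behave as follows; a sketch assuming they remain importable as internal functions from `litellm.utils`, with illustrative model names:

```python
from litellm.utils import _check_provider_match, _strip_model_name  # internal helpers

# Bedrock names are normalized through AmazonConverseConfig()._get_base_model, so a
# cross-region prefix such as "us." is stripped before the model-cost lookup.
print(_strip_model_name(model="us.amazon.nova-micro-v1:0", custom_llm_provider="bedrock"))

# Provider matching now accepts prefixed providers instead of requiring an exact match.
print(_check_provider_match({"litellm_provider": "vertex_ai-language-models"}, "vertex_ai"))  # True
print(_check_provider_match({"litellm_provider": "bedrock_converse"}, "bedrock"))             # True
print(_check_provider_match({"litellm_provider": "openai"}, "azure"))                         # False
```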
@@ -4388,6 +4426,7 @@ def get_model_info( # noqa: PLR0915
supports_prompt_caching: Optional[bool]
supports_audio_input: Optional[bool]
supports_audio_output: Optional[bool]
supports_pdf_input: Optional[bool]
Raises:
Exception: If the model is not mapped yet.
@@ -4445,15 +4484,21 @@ def get_model_info( # noqa: PLR0915
except Exception:
split_model = model
combined_model_name = model
stripped_model_name = _strip_model_name(model=model)
stripped_model_name = _strip_model_name(
model=model, custom_llm_provider=custom_llm_provider
)
combined_stripped_model_name = stripped_model_name
else:
split_model = model
combined_model_name = "{}/{}".format(custom_llm_provider, model)
stripped_model_name = _strip_model_name(model=model)
combined_stripped_model_name = "{}/{}".format(
custom_llm_provider, _strip_model_name(model=model)
stripped_model_name = _strip_model_name(
model=model, custom_llm_provider=custom_llm_provider
)
combined_stripped_model_name = "{}/{}".format(
custom_llm_provider,
_strip_model_name(model=model, custom_llm_provider=custom_llm_provider),
)
#########################
supported_openai_params = litellm.get_supported_openai_params(
@ -4476,6 +4521,7 @@ def get_model_info( # noqa: PLR0915
supports_function_calling=None,
supports_assistant_prefill=None,
supports_prompt_caching=None,
supports_pdf_input=None,
)
elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat":
return litellm.OllamaConfig().get_model_info(model)
@@ -4488,40 +4534,25 @@ def get_model_info( # noqa: PLR0915
4. 'stripped_model_name' in litellm.model_cost. Checks if 'ft:gpt-3.5-turbo' in model map, if 'ft:gpt-3.5-turbo:my-org:custom_suffix:id' given.
5. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
"""
_model_info: Optional[Dict[str, Any]] = None
key: Optional[str] = None
if combined_model_name in litellm.model_cost:
key = combined_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None and model in litellm.model_cost:
key = model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if (
_model_info is None
and combined_stripped_model_name in litellm.model_cost
@@ -4529,57 +4560,26 @@ def get_model_info( # noqa: PLR0915
key = combined_stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None and stripped_model_name in litellm.model_cost:
key = stripped_model_name
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None and split_model in litellm.model_cost:
key = split_model
_model_info = _get_model_info_from_model_cost(key=key)
_model_info["supported_openai_params"] = supported_openai_params
if (
"litellm_provider" in _model_info
and _model_info["litellm_provider"] != custom_llm_provider
if not _check_provider_match(
model_info=_model_info, custom_llm_provider=custom_llm_provider
):
if custom_llm_provider == "vertex_ai" and _model_info[
"litellm_provider"
].startswith("vertex_ai"):
pass
elif custom_llm_provider == "fireworks_ai" and _model_info[
"litellm_provider"
].startswith("fireworks_ai"):
pass
else:
_model_info = None
_model_info = None
if _model_info is None or key is None:
raise ValueError(
"This model isn't mapped yet. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json"
@@ -4675,6 +4675,7 @@ def get_model_info( # noqa: PLR0915
),
supports_audio_input=_model_info.get("supports_audio_input", False),
supports_audio_output=_model_info.get("supports_audio_output", False),
supports_pdf_input=_model_info.get("supports_pdf_input", False),
tpm=_model_info.get("tpm", None),
rpm=_model_info.get("rpm", None),
)
@@ -6195,11 +6196,21 @@ class ProviderConfigManager:
return OpenAIGPTConfig()
def get_end_user_id_for_cost_tracking(litellm_params: dict) -> Optional[str]:
def get_end_user_id_for_cost_tracking(
litellm_params: dict,
service_type: Literal["litellm_logging", "prometheus"] = "litellm_logging",
) -> Optional[str]:
"""
Used for enforcing `disable_end_user_cost_tracking` param.
service_type: "litellm_logging" or "prometheus" - used to allow prometheus only disable cost tracking.
"""
proxy_server_request = litellm_params.get("proxy_server_request") or {}
if litellm.disable_end_user_cost_tracking:
return None
if (
service_type == "prometheus"
and litellm.disable_end_user_cost_tracking_prometheus_only
):
return None
return proxy_server_request.get("body", {}).get("user", None)
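A sketch of how the updated helper behaves, assuming it remains importable from `litellm.utils`; the request body is illustrative:

```python
import litellm
from litellm.utils import get_end_user_id_for_cost_tracking

litellm.disable_end_user_cost_tracking_prometheus_only = True

litellm_params = {"proxy_server_request": {"body": {"user": "end-user-123"}}}

# Logging/DB cost tracking still sees the end user...
print(get_end_user_id_for_cost_tracking(litellm_params))                             # "end-user-123"
# ...while the Prometheus callback gets None, so no per-end-user labels are emitted.
print(get_end_user_id_for_cost_tracking(litellm_params, service_type="prometheus"))  # None
```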