Merge branch 'BerriAI:main' into main

greenscale-nandesh authored 2024-04-17 12:24:29 -07:00, committed by GitHub
commit 907e3973fd
44 changed files with 1001 additions and 156 deletions

@@ -5436,7 +5436,9 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]:
     get_api_base(model="gemini/gemini-pro")
     ```
     """
-    _optional_params = LiteLLM_Params(**optional_params) # convert to pydantic object
+    _optional_params = LiteLLM_Params(
+        model=model, **optional_params
+    ) # convert to pydantic object
     # get llm provider
     try:
         model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
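
The hunk above passes `model` into the pydantic constructor explicitly instead of relying on `optional_params` to carry it. A minimal sketch of why that matters, using a hypothetical stand-in class (not litellm's actual `LiteLLM_Params` definition) and assuming `model` is a required field:

```python
# Hypothetical stand-in for LiteLLM_Params, assuming `model` is required.
from typing import Optional
from pydantic import BaseModel, ValidationError

class Params(BaseModel):
    model: str                      # required: construction fails without it
    api_base: Optional[str] = None  # optional, like most call params

optional_params = {"api_base": "https://example.invalid"}

try:
    Params(**optional_params)  # old pattern: no "model" key -> ValidationError
except ValidationError as e:
    print("old pattern fails:", e.errors()[0]["loc"])

p = Params(model="gemini/gemini-pro", **optional_params)  # new pattern
print("new pattern ok:", p.model)
```
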
@@ -7842,6 +7844,19 @@ def exception_type(
                         response=original_exception.response,
                     )
             elif custom_llm_provider == "vertex_ai":
+                if completion_kwargs is not None:
+                    # add model, deployment and model_group to the exception message
+                    _model = completion_kwargs.get("model")
+                    _kwargs = completion_kwargs.get("kwargs", {}) or {}
+                    _metadata = _kwargs.get("metadata", {}) or {}
+                    _model_group = _metadata.get("model_group")
+                    _deployment = _metadata.get("deployment")
+                    error_str += f"\nmodel: {_model}\n"
+                    if _model_group is not None:
+                        error_str += f"model_group: {_model_group}\n"
+                    if _deployment is not None:
+                        error_str += f"deployment: {_deployment}\n"
+
                 if (
                     "Vertex AI API has not been used in project" in error_str
                     or "Unable to find your project" in error_str
@@ -10609,16 +10624,16 @@ def trim_messages(
     messages = copy.deepcopy(messages)
     try:
         print_verbose(f"trimming messages")
-        if max_tokens == None:
+        if max_tokens is None:
             # Check if model is valid
             if model in litellm.model_cost:
-                max_tokens_for_model = litellm.model_cost[model]["max_tokens"]
+                max_tokens_for_model = litellm.model_cost[model].get("max_input_tokens", litellm.model_cost[model]["max_tokens"])
                 max_tokens = int(max_tokens_for_model * trim_ratio)
             else:
-                # if user did not specify max tokens
+                # if user did not specify max (input) tokens
                 # or passed an llm litellm does not know
                 # do nothing, just return messages
-                return
+                return messages
 
         system_message = ""
         for message in messages:
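
Two behavioral fixes land in this hunk: the token budget now prefers `max_input_tokens` (falling back to `max_tokens` for entries that lack it), and the unknown-model path returns the untouched `messages` instead of a bare `return` (i.e. `None`). A small sketch of the lookup, with hypothetical cost-map entries shaped like `litellm.model_cost` (real values differ):

```python
from typing import Optional

# Hypothetical entries shaped like litellm.model_cost; real values differ.
model_cost = {
    "newer-model": {"max_tokens": 4096, "max_input_tokens": 8192},
    "older-model": {"max_tokens": 2048},  # no max_input_tokens recorded
}

def token_budget(model: str, trim_ratio: float = 0.75) -> Optional[int]:
    info = model_cost.get(model)
    if info is None:
        return None  # unknown model: trim_messages now returns messages as-is
    # prefer the input-token limit; fall back to the combined limit
    return int(info.get("max_input_tokens", info["max_tokens"]) * trim_ratio)

print(token_budget("newer-model"))  # 6144, from max_input_tokens
print(token_budget("older-model"))  # 1536, fallback to max_tokens
```

Returning `messages` instead of bare `return` matters because callers use the return value directly; the old code handed them `None` whenever the model was not in the cost map.
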