feat(vertex_ai_context_caching.py): check gemini cache, if key already exists

Krrish Dholakia 2024-08-26 20:28:18 -07:00
parent b0cc1df2d6
commit 0eea01dae9
6 changed files with 171 additions and 56 deletions

litellm/utils.py

@@ -5070,6 +5070,10 @@ def get_max_tokens(model: str) -> Optional[int]:
         )


+def _strip_stable_vertex_version(model_name) -> str:
+    return re.sub(r"-\d+$", "", model_name)
+
+
 def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo:
     """
     Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model.
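Note: the new helper only strips a trailing numeric stable-version suffix; anything else passes through unchanged. A minimal sketch of that behavior, with assumed example model names:

import re

def _strip_stable_vertex_version(model_name) -> str:
    # Drop a trailing "-<digits>" suffix such as "-001"; other names are returned as-is.
    return re.sub(r"-\d+$", "", model_name)

assert _strip_stable_vertex_version("gemini-1.5-flash-001") == "gemini-1.5-flash"
assert _strip_stable_vertex_version("gemini-1.5-flash") == "gemini-1.5-flash"  # no version suffix, unchanged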
@@ -5171,9 +5175,15 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
             except:
                 pass
             combined_model_name = model
+            combined_stripped_model_name = _strip_stable_vertex_version(
+                model_name=model
+            )
         else:
             split_model = model
             combined_model_name = "{}/{}".format(custom_llm_provider, model)
+            combined_stripped_model_name = "{}/{}".format(
+                custom_llm_provider, _strip_stable_vertex_version(model_name=model)
+            )
         #########################

         supported_openai_params = litellm.get_supported_openai_params(
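Note: for a versioned Vertex model, the two lookup keys built above differ only in the version suffix. A sketch with assumed values, reusing the helper from the first hunk:

model = "gemini-1.5-flash-001"  # assumed example input
custom_llm_provider = "vertex_ai"

combined_model_name = "{}/{}".format(custom_llm_provider, model)
# -> "vertex_ai/gemini-1.5-flash-001" (exact, versioned key)
combined_stripped_model_name = "{}/{}".format(
    custom_llm_provider, _strip_stable_vertex_version(model_name=model)
)
# -> "vertex_ai/gemini-1.5-flash" (falls back to the unversioned map entry)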
@@ -5200,8 +5210,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
         """
         Check if: (in order of specificity)
         1. 'custom_llm_provider/model' in litellm.model_cost. Checks "groq/llama3-8b-8192" if model="llama3-8b-8192" and custom_llm_provider="groq"
-        2. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
-        3. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
+        2. 'combined_stripped_model_name' in litellm.model_cost. Checks if 'gemini/gemini-1.5-flash' in model map, if 'gemini/gemini-1.5-flash-001' given.
+        3. 'model' in litellm.model_cost. Checks "groq/llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192" and custom_llm_provider=None
+        4. 'split_model' in litellm.model_cost. Checks "llama3-8b-8192" in litellm.model_cost if model="groq/llama3-8b-8192"
         """
         if combined_model_name in litellm.model_cost:
             key = combined_model_name
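Note: a condensed sketch of the four-step lookup order the docstring describes, using assumed example keys; the actual code below branches explicitly and picks the returned key per branch rather than looping:

import litellm

candidates = [
    "vertex_ai/gemini-1.5-flash-001",  # 1. combined_model_name (provider/model as given)
    "vertex_ai/gemini-1.5-flash",      # 2. combined_stripped_model_name (version suffix removed)
    "gemini-1.5-flash-001",            # 3. model as passed in
    "gemini-1.5-flash",                # 4. split_model (provider prefix removed)
]
key = next((c for c in candidates if c in litellm.model_cost), None)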
@@ -5217,6 +5228,26 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     pass
                 else:
                     raise Exception
+        elif combined_stripped_model_name in litellm.model_cost:
+            key = model
+            _model_info = litellm.model_cost[combined_stripped_model_name]
+            _model_info["supported_openai_params"] = supported_openai_params
+            if (
+                "litellm_provider" in _model_info
+                and _model_info["litellm_provider"] != custom_llm_provider
+            ):
+                if custom_llm_provider == "vertex_ai" and _model_info[
+                    "litellm_provider"
+                ].startswith("vertex_ai"):
+                    pass
+                else:
+                    raise Exception(
+                        "Got provider={}, Expected provider={}, for model={}".format(
+                            _model_info["litellm_provider"],
+                            custom_llm_provider,
+                            model,
+                        )
+                    )
         elif model in litellm.model_cost:
             key = model
             _model_info = litellm.model_cost[model]
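Note: the nested vertex_ai check above exists because a model-map entry can record a more specific litellm_provider string than the caller passes. A rough sketch of the acceptance rule, with an assumed map value:

custom_llm_provider = "vertex_ai"
map_provider = "vertex_ai-language-models"  # assumed litellm_provider value from the model map

# Exact provider match is required, except that any "vertex_ai*" map entry is
# accepted when the caller asked for plain "vertex_ai"; otherwise the mismatch raises.
provider_ok = map_provider == custom_llm_provider or (
    custom_llm_provider == "vertex_ai" and map_provider.startswith("vertex_ai")
)
assert provider_ok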
@@ -5320,9 +5351,9 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     "supports_assistant_prefill", False
                 ),
             )
-    except Exception:
+    except Exception as e:
         raise Exception(
-            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+            "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format(
                 model, custom_llm_provider
             )
         )
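Note: end to end, the stripped key lets a versioned Vertex model resolve even when only the unversioned entry exists in the model map. A usage sketch (the model name is an assumed example):

import litellm

info = litellm.get_model_info(
    model="gemini-1.5-flash-001", custom_llm_provider="vertex_ai"
)
print(info["max_input_tokens"], info["input_cost_per_token"])
# An unmapped model instead raises the "This model isn't mapped yet. ..." exception above.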