Merge pull request #4478 from BerriAI/litellm_support_response_schema_param_vertex_ai_old

feat(vertex_httpx.py): support the 'response_schema' param for older vertex ai models
commit 58d0330cd7 by Krish Dholakia, 2024-06-29 20:17:39 -07:00 (committed by GitHub)
14 changed files with 444 additions and 171 deletions
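
The validation path added in this PR is driven entirely by the `response_format` argument the caller passes in. Below is a minimal sketch of the call shape the new condition in `post_call_processing` looks for; the model name and schema are illustrative, not taken from this diff:

```python
import litellm

# Illustrative JSON Schema; any schema dict can be passed as "response_schema".
movie_schema = {
    "type": "object",
    "properties": {"title": {"type": "string"}, "year": {"type": "integer"}},
    "required": ["title", "year"],
}

# The new branch only fires when response_format has type == "json_object",
# a dict-valued "response_schema", and "enforce_validation": True.
response = litellm.completion(
    model="vertex_ai/gemini-pro",  # illustrative "older" Vertex AI model
    messages=[{"role": "user", "content": "Return a movie as JSON."}],
    response_format={
        "type": "json_object",
        "response_schema": movie_schema,
        "enforce_validation": True,
    },
)
print(response.choices[0].message.content)
```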


@@ -48,6 +48,7 @@ from tokenizers import Tokenizer
 import litellm
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 import litellm.litellm_core_utils
+import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.exception_mapping_utils import get_error_message
@@ -580,7 +581,7 @@ def client(original_function):
             else:
                 return False

-    def post_call_processing(original_response, model):
+    def post_call_processing(original_response, model, optional_params: Optional[dict]):
         try:
             if original_response is None:
                 pass
@@ -595,11 +596,47 @@ def client(original_function):
                         pass
                     else:
                         if isinstance(original_response, ModelResponse):
-                            model_response = original_response.choices[
+                            model_response: Optional[str] = original_response.choices[
                                 0
-                            ].message.content
-                            ### POST-CALL RULES ###
-                            rules_obj.post_call_rules(input=model_response, model=model)
+                            ].message.content  # type: ignore
+                            if model_response is not None:
+                                ### POST-CALL RULES ###
+                                rules_obj.post_call_rules(
+                                    input=model_response, model=model
+                                )
+                                ### JSON SCHEMA VALIDATION ###
+                                if (
+                                    optional_params is not None
+                                    and "response_format" in optional_params
+                                    and isinstance(
+                                        optional_params["response_format"], dict
+                                    )
+                                    and "type" in optional_params["response_format"]
+                                    and optional_params["response_format"]["type"]
+                                    == "json_object"
+                                    and "response_schema"
+                                    in optional_params["response_format"]
+                                    and isinstance(
+                                        optional_params["response_format"][
+                                            "response_schema"
+                                        ],
+                                        dict,
+                                    )
+                                    and "enforce_validation"
+                                    in optional_params["response_format"]
+                                    and optional_params["response_format"][
+                                        "enforce_validation"
+                                    ]
+                                    is True
+                                ):
+                                    # schema given, json response expected, and validation enforced
+                                    litellm.litellm_core_utils.json_validation_rule.validate_schema(
+                                        schema=optional_params["response_format"][
+                                            "response_schema"
+                                        ],
+                                        response=model_response,
+                                    )
         except Exception as e:
             raise e
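
The `validate_schema` helper referenced above lives in `litellm.litellm_core_utils.json_validation_rule`, which is not part of this diff. A rough sketch of what such a rule has to do (parse the raw model output as JSON, then check it against the caller's schema), assuming the `jsonschema` package; the exception type raised here is illustrative:

```python
import json

from jsonschema import ValidationError, validate


def validate_schema(schema: dict, response: str) -> None:
    """Validate a raw model response string against a JSON schema."""
    try:
        response_dict = json.loads(response)
    except json.JSONDecodeError:
        # Not valid JSON at all, so it cannot satisfy the schema.
        raise ValueError(f"Response is not valid JSON: {response}")

    try:
        validate(instance=response_dict, schema=schema)
    except ValidationError as e:
        # Surface the failure so post_call_processing re-raises it to the caller.
        raise ValueError(f"Response does not match schema: {e.message}")
```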
@@ -868,7 +905,11 @@ def client(original_function):
                     return result

            ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model or None)
+            post_call_processing(
+                original_response=result,
+                model=model or None,
+                optional_params=kwargs,
+            )

            # [OPTIONAL] ADD TO CACHE
            if (
@@ -1317,7 +1358,9 @@ def client(original_function):
            ).total_seconds() * 1000  # return response latency in ms like openai

            ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model)
+            post_call_processing(
+                original_response=result, model=model, optional_params=kwargs
+            )

            # [OPTIONAL] ADD TO CACHE
            if (
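
Both call sites above (the sync and async success paths of the `client` wrapper) now forward the caller's original `kwargs` as `optional_params`, which is how `response_format` reaches the validation branch. A stripped-down sketch of that pattern, with every name illustrative rather than copied from litellm:

```python
import functools
from typing import Optional


def client(original_function):
    """Toy wrapper: run the call, then hand the request kwargs to post-processing."""

    def post_call_processing(original_response, model, optional_params: Optional[dict]):
        response_format = (optional_params or {}).get("response_format")
        if isinstance(response_format, dict) and response_format.get("enforce_validation") is True:
            print("would validate against", response_format.get("response_schema"))

    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        result = original_function(*args, **kwargs)
        # The change in these hunks: thread the raw request kwargs through as optional_params.
        post_call_processing(
            original_response=result,
            model=kwargs.get("model"),
            optional_params=kwargs,
        )
        return result

    return wrapper
```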
@@ -1880,8 +1923,7 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
     Returns:
         bool: True if the model supports response_schema, False otherwise.

-    Raises:
-        Exception: If the given model is not found in model_prices_and_context_window.json.
+    Does not raise error. Defaults to 'False'. Outputs logging.error.
     """
     try:
         ## GET LLM PROVIDER ##
@@ -1901,9 +1943,10 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) -> bool:
             return True
         return False
     except Exception:
-        raise Exception(
+        verbose_logger.error(
             f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
         )
+        return False


 def supports_function_calling(model: str) -> bool:
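
With the error-handling change above, looking up a model that is missing from model_prices_and_context_window.json now logs an error and returns False instead of raising. A quick usage sketch, assuming the helper is imported from litellm.utils (the module this diff edits); the model names are illustrative:

```python
from litellm.utils import supports_response_schema

# Returns True only if the model's entry in model_prices_and_context_window.json
# advertises supports_response_schema.
print(supports_response_schema(model="gemini-1.5-pro", custom_llm_provider="vertex_ai"))

# Unknown model: previously this raised an Exception; after this change it
# logs via verbose_logger.error and falls back to False.
print(supports_response_schema(model="my-unknown-model", custom_llm_provider="vertex_ai"))  # False
```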