Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

Merge pull request #4478 from BerriAI/litellm_support_response_schema_param_vertex_ai_old
feat(vertex_httpx.py): support the 'response_schema' param for older vertex ai models

Commit 58d0330cd7: 14 changed files with 444 additions and 171 deletions
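The diff below threads the caller's `response_format` settings into litellm's post-call processing so that JSON responses can be validated against a caller-supplied schema. The snippet that follows is a hypothetical usage sketch, not code from this PR: the model name is illustrative, and it assumes the `response_format` argument to `litellm.completion()` is forwarded into the wrapper's `optional_params` (the call sites below pass `kwargs`), with a validation failure surfacing as an exception.

```python
# Hypothetical usage sketch for the behaviour added in this PR.
# Assumes response_format passed to litellm.completion() reaches the
# wrapper's optional_params, and that a schema mismatch raises.
import litellm

response_schema = {
    "type": "object",
    "properties": {"city": {"type": "string"}, "population": {"type": "integer"}},
    "required": ["city", "population"],
}

try:
    resp = litellm.completion(
        model="vertex_ai/gemini-1.0-pro",  # illustrative "older" Vertex AI model name
        messages=[{"role": "user", "content": "Give me a city and its population as JSON."}],
        response_format={
            "type": "json_object",
            "response_schema": response_schema,
            "enforce_validation": True,  # triggers the new client-side schema check
        },
    )
    print(resp.choices[0].message.content)
except Exception as err:
    # With enforce_validation=True, a response that does not match the
    # schema is rejected during post-call processing.
    print(f"schema validation failed: {err}")
```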
@@ -48,6 +48,7 @@ from tokenizers import Tokenizer
 import litellm
 import litellm._service_logger  # for storing API inputs, outputs, and metadata
 import litellm.litellm_core_utils
+import litellm.litellm_core_utils.json_validation_rule
 from litellm.caching import DualCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.exception_mapping_utils import get_error_message
@@ -580,7 +581,7 @@ def client(original_function):
             else:
                 return False

-    def post_call_processing(original_response, model):
+    def post_call_processing(original_response, model, optional_params: Optional[dict]):
         try:
             if original_response is None:
                 pass
@@ -595,11 +596,47 @@ def client(original_function):
                         pass
                     else:
                         if isinstance(original_response, ModelResponse):
-                            model_response = original_response.choices[
+                            model_response: Optional[str] = original_response.choices[
                                 0
-                            ].message.content
-                            ### POST-CALL RULES ###
-                            rules_obj.post_call_rules(input=model_response, model=model)
+                            ].message.content  # type: ignore
+                            if model_response is not None:
+                                ### POST-CALL RULES ###
+                                rules_obj.post_call_rules(
+                                    input=model_response, model=model
+                                )
+                                ### JSON SCHEMA VALIDATION ###
+                                if (
+                                    optional_params is not None
+                                    and "response_format" in optional_params
+                                    and isinstance(
+                                        optional_params["response_format"], dict
+                                    )
+                                    and "type" in optional_params["response_format"]
+                                    and optional_params["response_format"]["type"]
+                                    == "json_object"
+                                    and "response_schema"
+                                    in optional_params["response_format"]
+                                    and isinstance(
+                                        optional_params["response_format"][
+                                            "response_schema"
+                                        ],
+                                        dict,
+                                    )
+                                    and "enforce_validation"
+                                    in optional_params["response_format"]
+                                    and optional_params["response_format"][
+                                        "enforce_validation"
+                                    ]
+                                    is True
+                                ):
+                                    # schema given, json response expected, and validation enforced
+                                    litellm.litellm_core_utils.json_validation_rule.validate_schema(
+                                        schema=optional_params["response_format"][
+                                            "response_schema"
+                                        ],
+                                        response=model_response,
+                                    )

         except Exception as e:
             raise e
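The validation helper referenced above, `litellm.litellm_core_utils.json_validation_rule.validate_schema`, is not shown in this diff. Below is a minimal sketch of what such a helper could look like, assuming the `jsonschema` package and a plain `ValueError` on mismatch; the actual module in litellm may raise a dedicated exception type.

```python
# Hypothetical sketch of a validate_schema helper.
# Assumes the `jsonschema` package; the real litellm implementation
# (and the exception it raises) may differ.
import json

from jsonschema import ValidationError, validate


def validate_schema(schema: dict, response: str) -> None:
    """Check that `response` is valid JSON and conforms to `schema`."""
    try:
        response_dict = json.loads(response)
    except json.JSONDecodeError:
        raise ValueError("Response is not valid JSON, cannot validate schema.")

    try:
        validate(instance=response_dict, schema=schema)
    except ValidationError as e:
        raise ValueError(f"Response does not match the given schema: {e.message}")
```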
@@ -868,7 +905,11 @@ def client(original_function):
                     return result

             ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model or None)
+            post_call_processing(
+                original_response=result,
+                model=model or None,
+                optional_params=kwargs,
+            )

             # [OPTIONAL] ADD TO CACHE
             if (
@@ -1317,7 +1358,9 @@ def client(original_function):
             ).total_seconds() * 1000  # return response latency in ms like openai

             ### POST-CALL RULES ###
-            post_call_processing(original_response=result, model=model)
+            post_call_processing(
+                original_response=result, model=model, optional_params=kwargs
+            )

             # [OPTIONAL] ADD TO CACHE
             if (
@@ -1880,8 +1923,7 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) ->
     Returns:
         bool: True if the model supports response_schema, False otherwise.

-    Raises:
-        Exception: If the given model is not found in model_prices_and_context_window.json.
+    Does not raise error. Defaults to 'False'. Outputs logging.error.
     """
     try:
         ## GET LLM PROVIDER ##
@@ -1901,9 +1943,10 @@ def supports_response_schema(model: str, custom_llm_provider: Optional[str]) ->
             return True
         return False
     except Exception:
-        raise Exception(
+        verbose_logger.error(
             f"Model not in model_prices_and_context_window.json. You passed model={model}, custom_llm_provider={custom_llm_provider}."
         )
+        return False


 def supports_function_calling(model: str) -> bool:
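With this change, `supports_response_schema` degrades gracefully: an unknown model now logs an error and returns False instead of raising. A hedged usage sketch follows; it assumes the helper is exposed on the top-level `litellm` namespace like the other `supports_*` helpers, and the model name is only illustrative.

```python
# Hypothetical capability check before asking for schema enforcement.
# The model name is illustrative, not taken from this PR.
import litellm

model = "vertex_ai/gemini-1.0-pro"

# Before this PR an unknown model raised; now it logs an error and returns False.
if litellm.supports_response_schema(model=model, custom_llm_provider="vertex_ai"):
    print(f"{model} supports response_schema; enforce_validation can be used.")
else:
    print(f"{model} is not known to support response_schema; skipping enforcement.")
```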