fix(utils.py): support json schema validation

This commit is contained in:
Krrish Dholakia 2024-06-29 15:05:52 -07:00
parent 05dfc63b88
commit b699d9a8b9
5 changed files with 216 additions and 12 deletions

View file

@@ -48,6 +48,7 @@ from tokenizers import Tokenizer
import litellm
import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.json_validation_rule
from litellm.caching import DualCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
@@ -579,7 +580,7 @@ def client(original_function):
else:
return False
def post_call_processing(original_response, model):
def post_call_processing(original_response, model, optional_params: Optional[dict]):
try:
if original_response is None:
pass
@@ -594,11 +595,41 @@ def client(original_function):
pass
else:
if isinstance(original_response, ModelResponse):
model_response = original_response.choices[
model_response: Optional[str] = original_response.choices[
0
].message.content
### POST-CALL RULES ###
rules_obj.post_call_rules(input=model_response, model=model)
].message.content # type: ignore
if model_response is not None:
### POST-CALL RULES ###
rules_obj.post_call_rules(
input=model_response, model=model
)
### JSON SCHEMA VALIDATION ###
if (
optional_params is not None
and "response_format" in optional_params
and isinstance(
optional_params["response_format"], dict
)
and "type" in optional_params["response_format"]
and optional_params["response_format"]["type"]
== "json_object"
and "response_schema"
in optional_params["response_format"]
and isinstance(
optional_params["response_format"][
"response_schema"
],
dict,
)
):
# schema given, json response expected
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=optional_params["response_format"][
"response_schema"
],
response=model_response,
)
except Exception as e:
raise e
@@ -867,7 +898,11 @@ def client(original_function):
return result
### POST-CALL RULES ###
post_call_processing(original_response=result, model=model or None)
post_call_processing(
original_response=result,
model=model or None,
optional_params=kwargs,
)
# [OPTIONAL] ADD TO CACHE
if (