fix(utils.py): support json schema validation

Krrish Dholakia 2024-06-29 15:05:52 -07:00
parent 05dfc63b88
commit b699d9a8b9
5 changed files with 216 additions and 12 deletions


@@ -849,6 +849,7 @@ from .exceptions import (
APIResponseValidationError,
UnprocessableEntityError,
InternalServerError,
JSONSchemaValidationError,
LITELLM_EXCEPTION_TYPES,
)
from .budget_manager import BudgetManager


@@ -551,7 +551,7 @@ class APIError(openai.APIError): # type: ignore
message,
llm_provider,
model,
- request: httpx.Request,
+ request: Optional[httpx.Request] = None,
litellm_debug_info: Optional[str] = None,
max_retries: Optional[int] = None,
num_retries: Optional[int] = None,
@@ -563,6 +563,8 @@
self.litellm_debug_info = litellm_debug_info
self.max_retries = max_retries
self.num_retries = num_retries
if request is None:
request = httpx.Request(method="POST", url="https://api.openai.com/v1")
super().__init__(self.message, request=request, body=None) # type: ignore
def __str__(self):
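Making `request` optional (a placeholder `httpx.Request` is synthesized when it is omitted) lets `APIError` subclasses be raised in contexts where no HTTP request object exists, such as the client-side schema check added below. A minimal sketch of what this enables; the argument values are illustrative and not part of this commit:

import litellm

# No httpx.Request needs to be passed anymore; a placeholder POST request
# to the OpenAI base URL is constructed internally when `request` is None.
err = litellm.APIError(
    status_code=500,
    message="example error",
    llm_provider="openai",
    model="gpt-3.5-turbo",
)
print(str(err))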
@@ -664,6 +666,22 @@ class OpenAIError(openai.OpenAIError): # type: ignore
self.llm_provider = "openai"
class JSONSchemaValidationError(APIError):
def __init__(
self, model: str, llm_provider: str, raw_response: str, schema: str
) -> None:
self.raw_response = raw_response
self.schema = schema
self.model = model
message = "litellm.JSONSchemaValidationError: model={}, returned an invalid response={}, for schema={}.\nAccess raw response with `e.raw_response`".format(
model, raw_response, schema
)
self.message = message
super().__init__(
model=model, message=message, llm_provider=llm_provider, status_code=500
)
LITELLM_EXCEPTION_TYPES = [
AuthenticationError,
NotFoundError,
@@ -682,6 +700,7 @@ LITELLM_EXCEPTION_TYPES = [
APIResponseValidationError,
OpenAIError,
InternalServerError,
JSONSchemaValidationError,
]
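Because `JSONSchemaValidationError` is exported from `litellm` and subclasses `APIError`, callers can catch it directly and inspect the offending output via `e.raw_response` (the schema it was checked against is stored as a JSON string on `e.schema`). A rough usage sketch, assuming the `response_format` shape that the validation hook in this commit looks for; the model name and schema below are illustrative only:

import json

import litellm

try:
    resp = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro-001",
        messages=[{"role": "user", "content": "List some cookie recipes"}],
        response_format={
            "type": "json_object",
            "response_schema": {"type": "array", "items": {"type": "object"}},
        },
    )
except litellm.JSONSchemaValidationError as e:
    # The raw model output that failed validation, plus the schema it was
    # validated against (kept as a JSON-encoded string).
    print("invalid model output:", e.raw_response)
    print("expected schema:", json.loads(e.schema))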


@@ -0,0 +1,23 @@
import json
def validate_schema(schema: dict, response: str):
"""
Validate if the returned json response follows the schema.
Params:
- schema - dict: JSON schema
- response - str: Received json response as string.
"""
from jsonschema import ValidationError, validate
from litellm import JSONSchemaValidationError
response_dict = json.loads(response)
try:
validate(response_dict, schema=schema)
except ValidationError:
raise JSONSchemaValidationError(
model="", llm_provider="", raw_response=response, schema=json.dumps(schema)
)
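A standalone sketch of how this new helper behaves; the schema and response strings below are invented for illustration, and `jsonschema` must be installed:

from litellm import JSONSchemaValidationError
from litellm.litellm_core_utils.json_validation_rule import validate_schema

recipe_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {"recipe_name": {"type": "string"}},
        "required": ["recipe_name"],
    },
}

# A response that matches the schema passes silently.
validate_schema(schema=recipe_schema, response='[{"recipe_name": "Sugar Cookies"}]')

# A response with the wrong key fails jsonschema validation and is re-raised
# as litellm's JSONSchemaValidationError.
try:
    validate_schema(schema=recipe_schema, response='[{"recipe_world": "Sugar Cookies"}]')
except JSONSchemaValidationError as e:
    print(e.message)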


@@ -880,15 +880,133 @@ Using this JSON schema:
mock_call.assert_called_once()
def vertex_httpx_mock_post_valid_response(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"candidates": [
{
"content": {
"role": "model",
"parts": [
{
"text": '[{"recipe_name": "Chocolate Chip Cookies"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Sugar Cookies"}, {"recipe_name": "Snickerdoodles"}]\n'
}
],
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.09790669,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.11736965,
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.1261379,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08601588,
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.083441176,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.0355444,
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.071981624,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08108212,
},
],
}
],
"usageMetadata": {
"promptTokenCount": 60,
"candidatesTokenCount": 55,
"totalTokenCount": 115,
},
}
return mock_response
def vertex_httpx_mock_post_invalid_schema_response(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {
"candidates": [
{
"content": {
"role": "model",
"parts": [
{"text": '[{"recipe_world": "Chocolate Chip Cookies"}]\n'}
],
},
"finishReason": "STOP",
"safetyRatings": [
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.09790669,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.11736965,
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.1261379,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08601588,
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.083441176,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.0355444,
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"probability": "NEGLIGIBLE",
"probabilityScore": 0.071981624,
"severity": "HARM_SEVERITY_NEGLIGIBLE",
"severityScore": 0.08108212,
},
],
}
],
"usageMetadata": {
"promptTokenCount": 60,
"candidatesTokenCount": 55,
"totalTokenCount": 115,
},
}
return mock_response
@pytest.mark.parametrize(
"model, supports_response_schema",
[
("vertex_ai_beta/gemini-1.5-pro-001", True),
("vertex_ai_beta/gemini-1.5-flash", False),
],
) # "vertex_ai",
)
@pytest.mark.parametrize(
"invalid_response",
[True, False],
)
@pytest.mark.asyncio
- async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
+ async def test_gemini_pro_json_schema_args_sent_httpx(
+ model, supports_response_schema, invalid_response
+ ):
load_vertex_ai_credentials()
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
@@ -912,7 +1030,12 @@ async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
client = HTTPHandler()
with patch.object(client, "post", new=MagicMock()) as mock_call:
httpx_response = MagicMock()
if invalid_response is True:
httpx_response.side_effect = vertex_httpx_mock_post_invalid_schema_response
else:
httpx_response.side_effect = vertex_httpx_mock_post_valid_response
with patch.object(client, "post", new=httpx_response) as mock_call:
try:
_ = completion(
model=model,
@@ -923,8 +1046,11 @@ async def test_gemini_pro_json_schema_httpx(model, supports_response_schema):
},
client=client,
)
- except Exception:
- pass
+ if invalid_response is True:
+ pytest.fail("Expected this to fail")
+ except litellm.JSONSchemaValidationError as e:
+ if invalid_response is False:
+ pytest.fail("Expected this to pass. Got={}".format(e))
mock_call.assert_called_once()
print(mock_call.call_args.kwargs)


@@ -48,6 +48,7 @@ from tokenizers import Tokenizer
import litellm
import litellm._service_logger # for storing API inputs, outputs, and metadata
import litellm.litellm_core_utils
import litellm.litellm_core_utils.json_validation_rule
from litellm.caching import DualCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.llm_request_utils import _ensure_extra_body_is_safe
@@ -579,7 +580,7 @@ def client(original_function):
else:
return False
- def post_call_processing(original_response, model):
+ def post_call_processing(original_response, model, optional_params: Optional[dict]):
try:
if original_response is None:
pass
@@ -594,11 +595,41 @@
pass
else:
if isinstance(original_response, ModelResponse):
- model_response = original_response.choices[
+ model_response: Optional[str] = original_response.choices[
0
- ].message.content
+ ].message.content # type: ignore
+ if model_response is not None:
### POST-CALL RULES ###
- rules_obj.post_call_rules(input=model_response, model=model)
+ rules_obj.post_call_rules(
+ input=model_response, model=model
+ )
### JSON SCHEMA VALIDATION ###
if (
optional_params is not None
and "response_format" in optional_params
and isinstance(
optional_params["response_format"], dict
)
and "type" in optional_params["response_format"]
and optional_params["response_format"]["type"]
== "json_object"
and "response_schema"
in optional_params["response_format"]
and isinstance(
optional_params["response_format"][
"response_schema"
],
dict,
)
):
# schema given, json response expected
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=optional_params["response_format"][
"response_schema"
],
response=model_response,
)
except Exception as e:
raise e
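For clarity, the long guard above only triggers schema validation when the original request's `response_format` asks for a JSON object and also carries a dict-valued `response_schema`. A condensed restatement of that check; `_should_validate_schema` is a hypothetical helper written here for illustration, not part of this commit, and it assumes `optional_params` is not None:

def _should_validate_schema(optional_params: dict) -> bool:
    # response_format must request a JSON object and include the JSON schema
    # itself under "response_schema" for validate_schema() to run.
    rf = optional_params.get("response_format")
    return (
        isinstance(rf, dict)
        and rf.get("type") == "json_object"
        and isinstance(rf.get("response_schema"), dict)
    )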
@@ -867,7 +898,11 @@ def client(original_function):
return result
### POST-CALL RULES ###
- post_call_processing(original_response=result, model=model or None)
+ post_call_processing(
+ original_response=result,
+ model=model or None,
+ optional_params=kwargs,
+ )
# [OPTIONAL] ADD TO CACHE
if (