feat(utils.py): support validating json schema client-side if user opts in

Krrish Dholakia 2024-08-06 19:35:33 -07:00
parent 5dfde2ee0b
commit 2dd27a4e12
4 changed files with 117 additions and 62 deletions


@@ -69,7 +69,10 @@ To use Structured Outputs, simply specify
response_format: { "type": "json_schema", "json_schema": … , "strict": true }
```
Works for OpenAI models
Works for:
- OpenAI models
- Google AI Studio - Gemini models
- Vertex AI models (Gemini + Anthropic)
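For orientation, here is a minimal sketch of a call in this format (the model name and schema are placeholders; complete SDK and proxy examples follow in the tabs below):

```python
# Minimal sketch - model name and schema are placeholders, and GEMINI_API_KEY is assumed
# to be set in the environment for the Google AI Studio route.
from litellm import completion

resp = completion(
    model="gemini/gemini-1.5-pro",  # any provider from the list above works
    messages=[{"role": "user", "content": "List 5 cookie recipes"}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "recipes",
            "schema": {
                "type": "object",
                "properties": {"recipe_name": {"type": "string"}},
                "required": ["recipe_name"],
            },
            "strict": True,
        },
    },
)
print(resp.choices[0].message.content)
```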
<Tabs>
<TabItem value="sdk" label="SDK">
@@ -202,15 +205,15 @@ curl -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
## Validate JSON Schema
:::info
Support for doing this in the openai 'json_schema' format will be [coming soon](https://github.com/BerriAI/litellm/issues/5074#issuecomment-2272355842)
Not all vertex models support passing the json_schema to them (e.g. `gemini-1.5-flash`). To solve this, LiteLLM supports client-side validation of the json schema.
:::
```python
litellm.enable_json_schema_validation=True
```
If `litellm.enable_json_schema_validation=True` is set, LiteLLM will validate the JSON response against the provided schema on the client side.
For VertexAI models, LiteLLM supports passing the `response_schema` and validating the JSON output.
This works across Gemini (`vertex_ai_beta/`) + Anthropic (`vertex_ai/`) models.
[**See Code**](https://github.com/BerriAI/litellm/blob/671d8ac496b6229970c7f2a3bdedd6cb84f0746b/litellm/litellm_core_utils/json_validation_rule.py#L4)
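For intuition, here is a minimal standalone sketch of what that client-side check amounts to, assuming the `jsonschema` package is available (the linked `validate_schema` helper is the authoritative implementation; the function below is illustrative only):

```python
# Illustrative sketch of client-side JSON schema validation (not litellm's actual helper).
import json

from jsonschema import ValidationError, validate  # assumes `jsonschema` is installed


def check_response_against_schema(schema: dict, response: str) -> None:
    """Raise if the model's raw JSON output does not satisfy the given schema."""
    try:
        parsed = json.loads(response)
    except json.JSONDecodeError as err:
        raise ValueError(f"Model did not return valid JSON: {err}")

    try:
        validate(instance=parsed, schema=schema)
    except ValidationError as err:
        raise ValueError(f"Response does not match the schema: {err.message}")
```

A failed check is surfaced as an exception rather than a silently non-conforming response.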
<Tabs>
@@ -218,33 +221,28 @@ This works across Gemini (`vertex_ai_beta/`) + Anthropic (`vertex_ai/`) models.
```python
# !gcloud auth application-default login - run this to add vertex credentials to your env
import litellm, os
from litellm import completion
from pydantic import BaseModel
messages = [{"role": "user", "content": "List 5 cookie recipes"}]
response_schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"recipe_name": {
"type": "string",
},
},
"required": ["recipe_name"],
},
}
messages=[
{"role": "system", "content": "Extract the event information."},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
]
litellm.enable_json_schema_validation = True
litellm.set_verbose = True # see the raw request made by litellm
class CalendarEvent(BaseModel):
name: str
date: str
participants: list[str]
resp = completion(
model="vertex_ai_beta/gemini-1.5-pro",
model="gemini/gemini-1.5-pro",
messages=messages,
response_format={
"type": "json_object",
"response_schema": response_schema,
"enforce_validation": True, # client-side json schema validation
},
vertex_location="us-east5",
response_format=CalendarEvent,
)
print("Received={}".format(resp))
@@ -252,26 +250,63 @@ print("Received={}".format(resp))
</TabItem>
<TabItem value="proxy" label="PROXY">
1. Create config.yaml
```yaml
model_list:
- model_name: "gemini-1.5-flash"
litellm_params:
model: "gemini/gemini-1.5-flash"
api_key: os.environ/GEMINI_API_KEY
litellm_settings:
enable_json_schema_validation: True
```
2. Start proxy
```bash
litellm --config /path/to/config.yaml
```
3. Test it!
```bash
curl http://0.0.0.0:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $LITELLM_API_KEY" \
-d '{
"model": "vertex_ai_beta/gemini-1.5-pro",
"messages": [{"role": "user", "content": "List 5 cookie recipes"}]
"model": "gemini-1.5-flash",
"messages": [
{"role": "system", "content": "Extract the event information."},
{"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
],
"response_format": {
"type": "json_object",
"enforce_validation: true,
"response_schema": {
"type": "array",
"items": {
"type": "json_schema",
"json_schema": {
"name": "math_reasoning",
"schema": {
"type": "object",
"properties": {
"recipe_name": {
"type": "string",
},
"steps": {
"type": "array",
"items": {
"type": "object",
"properties": {
"explanation": { "type": "string" },
"output": { "type": "string" }
},
"required": ["explanation", "output"],
"additionalProperties": false
}
},
"final_answer": { "type": "string" }
},
"required": ["recipe_name"],
"required": ["steps", "final_answer"],
"additionalProperties": false
},
"strict": true
},
}
},
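Because the diff above interleaves the old `json_object` payload with the new `json_schema` one, here is the new request assembled as a self-contained Python sketch (the endpoint, key variable, and model name mirror the example above and are assumptions about your deployment):

```python
# Sketch of the same request as the curl example above, sent via the `requests` package.
import os

import requests  # assumed to be installed

payload = {
    "model": "gemini-1.5-flash",
    "messages": [
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "math_reasoning",
            "schema": {
                "type": "object",
                "properties": {
                    "steps": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "explanation": {"type": "string"},
                                "output": {"type": "string"},
                            },
                            "required": ["explanation", "output"],
                            "additionalProperties": False,
                        },
                    },
                    "final_answer": {"type": "string"},
                },
                "required": ["steps", "final_answer"],
                "additionalProperties": False,
            },
            "strict": True,
        },
    },
}

resp = requests.post(
    "http://0.0.0.0:4000/v1/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['LITELLM_API_KEY']}"},
    json=payload,
    timeout=60,
)
print(resp.json())
```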


@@ -144,6 +144,7 @@ enable_preview_features: bool = False
return_response_headers: bool = (
False # get response headers from LLM Api providers - example x-remaining-requests,
)
enable_json_schema_validation: bool = False
##################
logging: bool = True
enable_caching_on_provider_specific_optional_params: bool = (


@@ -4,4 +4,4 @@ model_list:
model: "*"
litellm_settings:
callbacks: ["lakera_prompt_injection"]
enable_json_schema_validation: true


@@ -631,8 +631,8 @@ def client(original_function):
call_type == CallTypes.completion.value
or call_type == CallTypes.acompletion.value
):
is_coroutine = check_coroutine(original_function)
if is_coroutine == True:
is_coroutine = check_coroutine(original_response)
if is_coroutine is True:
pass
else:
if isinstance(original_response, ModelResponse):
@@ -645,30 +645,49 @@ def client(original_function):
input=model_response, model=model
)
### JSON SCHEMA VALIDATION ###
try:
if (
optional_params is not None
and "response_format" in optional_params
and _parsing._completions.is_basemodel_type(
optional_params["response_format"]
)
):
json_response_format = (
type_to_response_format_param(
response_format=optional_params[
if litellm.enable_json_schema_validation is True:
try:
if (
optional_params is not None
and "response_format" in optional_params
and optional_params["response_format"]
is not None
):
json_response_format: Optional[dict] = None
if (
isinstance(
optional_params["response_format"],
dict,
)
and optional_params[
"response_format"
].get("json_schema")
is not None
):
json_response_format = optional_params[
"response_format"
]
)
)
if json_response_format is not None:
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=json_response_format[
"json_schema"
]["schema"],
response=model_response,
)
except TypeError:
pass
elif (
_parsing._completions.is_basemodel_type(
optional_params["response_format"]
)
):
json_response_format = (
type_to_response_format_param(
response_format=optional_params[
"response_format"
]
)
)
if json_response_format is not None:
litellm.litellm_core_utils.json_validation_rule.validate_schema(
schema=json_response_format[
"json_schema"
]["schema"],
response=model_response,
)
except TypeError:
pass
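In plain terms, the added branch derives a JSON schema from `response_format` (either an OpenAI-style dict carrying `json_schema`, or a pydantic model converted via `type_to_response_format_param`) and, when one is found, runs `validate_schema` against the model's output. A standalone sketch of that decision, assuming pydantic v2 and the `jsonschema` package as stand-ins for litellm's own helpers:

```python
# Standalone sketch of the logic added above (simplified names; litellm's real code
# runs inside the client wrapper and uses type_to_response_format_param / validate_schema).
import json
from typing import Optional, Type, Union

from jsonschema import validate  # stand-in for litellm's validate_schema helper
from pydantic import BaseModel  # assumes pydantic v2 for model_json_schema()


def maybe_validate(response_format: Union[dict, Type[BaseModel], None], raw_response: str) -> None:
    """Validate raw_response if a JSON schema can be derived from response_format."""
    json_response_format: Optional[dict] = None

    if isinstance(response_format, dict) and response_format.get("json_schema") is not None:
        # Caller passed the OpenAI-style {"type": "json_schema", "json_schema": {...}} dict as-is.
        json_response_format = response_format
    elif isinstance(response_format, type) and issubclass(response_format, BaseModel):
        # Caller passed a pydantic model class; derive an equivalent schema from it.
        json_response_format = {"json_schema": {"schema": response_format.model_json_schema()}}

    if json_response_format is not None:
        validate(
            instance=json.loads(raw_response),
            schema=json_response_format["json_schema"]["schema"],
        )
```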
if (
optional_params is not None
and "response_format" in optional_params