feat(utils.py): support validating json schema client-side if user opts in

Krrish Dholakia 2024-08-06 19:35:33 -07:00
parent 5dfde2ee0b
commit 2dd27a4e12
4 changed files with 117 additions and 62 deletions


@@ -69,7 +69,10 @@ To use Structured Outputs, simply specify
 response_format: { "type": "json_schema", "json_schema": … , "strict": true }
 ```
-Works for OpenAI models
+Works for:
+- OpenAI models
+- Google AI Studio - Gemini models
+- Vertex AI models (Gemini + Anthropic)
 <Tabs>
 <TabItem value="sdk" label="SDK">
@@ -202,15 +205,15 @@ curl -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
 ## Validate JSON Schema
-:::info
-Support for doing this in the openai 'json_schema' format will be [coming soon](https://github.com/BerriAI/litellm/issues/5074#issuecomment-2272355842)
-:::
-For VertexAI models, LiteLLM supports passing the `response_schema` and validating the JSON output.
-This works across Gemini (`vertex_ai_beta/`) + Anthropic (`vertex_ai/`) models.
+Not all vertex models support passing the json_schema to them (e.g. `gemini-1.5-flash`). To solve this, LiteLLM supports client-side validation of the json schema.
+```
+litellm.enable_json_schema_validation=True
+```
+If `litellm.enable_json_schema_validation=True` is set, LiteLLM will validate the json response using `jsonvalidator`.
+[**See Code**](https://github.com/BerriAI/litellm/blob/671d8ac496b6229970c7f2a3bdedd6cb84f0746b/litellm/litellm_core_utils/json_validation_rule.py#L4)
 <Tabs>
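The updated docs above promise that, once the flag is on, the JSON response is validated against the supplied schema (the linked `json_validation_rule.py` is the entry point). A minimal sketch of what such a client-side check could look like, assuming the `jsonschema` package; the `validate_schema` name and error handling are illustrative, not necessarily LiteLLM's exact implementation:

```python
# Hypothetical sketch of client-side JSON schema validation.
# Assumes the `jsonschema` package; names are illustrative only.
import json

from jsonschema import ValidationError, validate


def validate_schema(schema: dict, response: str) -> None:
    """Validate a raw JSON string returned by the LLM against a JSON schema."""
    try:
        response_dict = json.loads(response)
    except json.JSONDecodeError:
        raise ValueError("Response is not valid JSON")

    try:
        validate(instance=response_dict, schema=schema)
    except ValidationError as e:
        raise ValueError(f"Response does not match schema: {e.message}")


# Example: a response with the required key passes, a missing key would raise.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}
validate_schema(schema, '{"name": "science fair"}')
```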
@@ -218,33 +221,28 @@ This works across Gemini (`vertex_ai_beta/`) + Anthropic (`vertex_ai/`) models.
 ```python
 # !gcloud auth application-default login - run this to add vertex credentials to your env
+import litellm, os
 from litellm import completion
+from pydantic import BaseModel
 
-messages = [{"role": "user", "content": "List 5 cookie recipes"}]
+messages=[
+    {"role": "system", "content": "Extract the event information."},
+    {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+]
 
-response_schema = {
-    "type": "array",
-    "items": {
-        "type": "object",
-        "properties": {
-            "recipe_name": {
-                "type": "string",
-            },
-        },
-        "required": ["recipe_name"],
-    },
-}
+litellm.enable_json_schema_validation = True
+litellm.set_verbose = True # see the raw request made by litellm
+
+class CalendarEvent(BaseModel):
+    name: str
+    date: str
+    participants: list[str]
 
 resp = completion(
-    model="vertex_ai_beta/gemini-1.5-pro",
+    model="gemini/gemini-1.5-pro",
     messages=messages,
-    response_format={
-        "type": "json_object",
-        "response_schema": response_schema,
-        "enforce_validation": True, # client-side json schema validation
-    },
-    vertex_location="us-east5",
+    response_format=CalendarEvent,
 )
 
 print("Received={}".format(resp))
@@ -252,26 +250,63 @@ print("Received={}".format(resp))
 </TabItem>
 <TabItem value="proxy" label="PROXY">
 
+1. Create config.yaml
+```yaml
+model_list:
+  - model_name: "gemini-1.5-flash"
+    litellm_params:
+      model: "gemini/gemini-1.5-flash"
+      api_key: os.environ/GEMINI_API_KEY
+
+litellm_settings:
+  enable_json_schema_validation: True
+```
+
+2. Start proxy
+
+```bash
+litellm --config /path/to/config.yaml
+```
+
+3. Test it!
+
 ```bash
 curl http://0.0.0.0:4000/v1/chat/completions \
   -H "Content-Type: application/json" \
   -H "Authorization: Bearer $LITELLM_API_KEY" \
   -d '{
-    "model": "vertex_ai_beta/gemini-1.5-pro",
-    "messages": [{"role": "user", "content": "List 5 cookie recipes"}]
+    "model": "gemini-1.5-flash",
+    "messages": [
+        {"role": "system", "content": "Extract the event information."},
+        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
+    ],
     "response_format": {
         "type": "json_object",
-        "enforce_validation: true,
         "response_schema": {
-            "type": "array",
-            "items": {
-                "type": "object",
-                "properties": {
-                    "recipe_name": {
-                        "type": "string",
-                    },
-                },
-                "required": ["recipe_name"],
+            "type": "json_schema",
+            "json_schema": {
+                "name": "math_reasoning",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "steps": {
+                            "type": "array",
+                            "items": {
+                                "type": "object",
+                                "properties": {
+                                    "explanation": { "type": "string" },
+                                    "output": { "type": "string" }
+                                },
+                                "required": ["explanation", "output"],
+                                "additionalProperties": false
+                            }
+                        },
+                        "final_answer": { "type": "string" }
+                    },
+                    "required": ["steps", "final_answer"],
+                    "additionalProperties": false
+                },
+                "strict": true
             },
         }
     },
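For reference, the same proxy request can be issued with the OpenAI Python SDK pointed at the LiteLLM proxy instead of curl. A sketch under the same assumptions as the curl example (base URL and `LITELLM_API_KEY`); the nested `response_schema` payload is omitted here for brevity:

```python
# Sketch: call the LiteLLM proxy with the OpenAI SDK instead of curl.
# base_url / api_key mirror the curl example; adjust for your deployment.
import os

from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key=os.environ["LITELLM_API_KEY"])

resp = client.chat.completions.create(
    model="gemini-1.5-flash",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format={"type": "json_object"},
)
print(resp.choices[0].message.content)
```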


@@ -144,6 +144,7 @@ enable_preview_features: bool = False
 return_response_headers: bool = (
     False  # get response headers from LLM Api providers - example x-remaining-requests,
 )
+enable_json_schema_validation: bool = False
 ##################
 logging: bool = True
 enable_caching_on_provider_specific_optional_params: bool = (


@@ -4,4 +4,4 @@ model_list:
     model: "*"
 
 litellm_settings:
-  callbacks: ["lakera_prompt_injection"]
+  enable_json_schema_validation: true
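The `litellm_settings` block above is how the proxy config reaches the module-level flag added in `__init__.py`; conceptually, each key is applied as an attribute on the `litellm` module. A rough, illustrative sketch of that mapping, not the proxy's actual config loader:

```python
# Rough sketch of how a litellm_settings block could be applied.
# Illustrative only; the real proxy config loader is more involved.
import litellm
import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

for key, value in (config.get("litellm_settings") or {}).items():
    setattr(litellm, key, value)  # e.g. litellm.enable_json_schema_validation = True

print(litellm.enable_json_schema_validation)
```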


@@ -631,8 +631,8 @@ def client(original_function):
                 call_type == CallTypes.completion.value
                 or call_type == CallTypes.acompletion.value
             ):
-                is_coroutine = check_coroutine(original_function)
-                if is_coroutine == True:
+                is_coroutine = check_coroutine(original_response)
+                if is_coroutine is True:
                     pass
                 else:
                     if isinstance(original_response, ModelResponse):
@@ -645,11 +645,30 @@ def client(original_function):
                 input=model_response, model=model
             )
             ### JSON SCHEMA VALIDATION ###
+            if litellm.enable_json_schema_validation is True:
                 try:
                     if (
                         optional_params is not None
                         and "response_format" in optional_params
-                        and _parsing._completions.is_basemodel_type(
+                        and optional_params["response_format"]
+                        is not None
+                    ):
+                        json_response_format: Optional[dict] = None
+                        if (
+                            isinstance(
+                                optional_params["response_format"],
+                                dict,
+                            )
+                            and optional_params[
+                                "response_format"
+                            ].get("json_schema")
+                            is not None
+                        ):
+                            json_response_format = optional_params[
+                                "response_format"
+                            ]
+                        elif (
+                            _parsing._completions.is_basemodel_type(
                                 optional_params["response_format"]
                             )
                         ):
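The hunk above is cut off before the validation call itself, but it implies the follow-through: once `json_response_format` is populated (from a raw dict, or from a Pydantic model in the `elif` branch), its schema can be checked against the returned message content. A hedged sketch of that remaining step; the helper below is an assumption, not code from this commit:

```python
# Hedged sketch of the step the truncated hunk leads up to:
# pull the schema out of the response_format and validate the LLM output.
# Helper names are assumptions, not necessarily this commit's code.
import json
from typing import Optional

from jsonschema import validate
from pydantic import BaseModel


def validate_response_against_format(response_text: str, response_format) -> None:
    json_response_format: Optional[dict] = None

    if isinstance(response_format, dict) and response_format.get("json_schema") is not None:
        json_response_format = response_format
    elif isinstance(response_format, type) and issubclass(response_format, BaseModel):
        # Convert the Pydantic model into the same dict shape.
        json_response_format = {
            "json_schema": {"schema": response_format.model_json_schema()}
        }

    if json_response_format is not None:
        schema = json_response_format["json_schema"]["schema"]
        validate(instance=json.loads(response_text), schema=schema)
```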