add parameter mapping with vertex ai

This commit is contained in:
Ishaan Jaff 2024-05-20 13:28:20 -07:00
parent 2c25bfa8df
commit 518db13982
4 changed files with 30 additions and 2 deletions

View file

@ -521,6 +521,19 @@ response = await litellm.aimage_generation(
)
```
**Generating multiple images**
Use the `n` parameter to pass how many images you want generated
```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
```
## Extra
### Using `GOOGLE_APPLICATION_CREDENTIALS`

View file

@ -153,15 +153,21 @@ class VertexLLM(BaseLLM):
{
"prompt": "a cat"
}
]
],
"parameters": {
"sampleCount": 1
}
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header = self._ensure_access_token()
optional_params = optional_params or {
"sampleCount": 1
} # default optional params
request_data = {
"instances": [{"prompt": prompt}],
"parameters": {"sampleCount": 1},
"parameters": optional_params,
}
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""

View file

@ -184,6 +184,7 @@ async def test_aimage_generation_vertex_ai():
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
assert response.data is not None
assert len(response.data) > 0

View file

@ -4946,6 +4946,14 @@ def get_optional_params_image_gen(
width, height = size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)
elif custom_llm_provider == "vertex_ai":
supported_params = ["n"]
"""
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
"""
_check_valid_arg(supported_params=supported_params)
if n is not None:
optional_params["sampleCount"] = int(n)
for k in passed_params.keys():
if k not in default_params.keys():