use litellm proxy with vertex ai sdk

This commit is contained in:
Ishaan Jaff 2024-08-21 17:46:23 -07:00
parent f310ce541f
commit 0ea1f367d7
3 changed files with 233 additions and 6 deletions

View file

@ -84,7 +84,6 @@ vertexai.init(
api_endpoint=LITELLM_PROXY_BASE, api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials, credentials=credentials,
api_transport="rest", api_transport="rest",
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
) )
model = GenerativeModel("gemini-1.5-flash-001") model = GenerativeModel("gemini-1.5-flash-001")
@ -143,7 +142,7 @@ vertexai.init(
api_endpoint=LITELLM_PROXY_BASE, api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials, credentials=credentials,
api_transport="rest", api_transport="rest",
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
) )
model = GenerativeModel("gemini-1.5-flash-001") model = GenerativeModel("gemini-1.5-flash-001")
@ -216,7 +215,7 @@ vertexai.init(
api_endpoint=LITELLM_PROXY_BASE, api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials, credentials=credentials,
api_transport="rest", api_transport="rest",
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")], )
def embed_text( def embed_text(
@ -249,6 +248,80 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-geck
### Imagen API
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
from vertexai.preview.vision_models import ImageGenerationModel
import vertexai
from google.auth.credentials import Credentials

LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"


class CredentialsWrapper(Credentials):
    """google-auth credentials that present a static LiteLLM proxy key as a Bearer token."""

    def __init__(self, token=None):
        super().__init__()
        self.token = token
        self.expiry = None  # or set to a future date if needed

    def refresh(self, request):
        # No-op: the proxy key is static and never refreshed.
        pass

    def apply(self, headers, token=None):
        headers["Authorization"] = f"Bearer {self.token}"

    @property
    def expired(self):
        return False  # Always consider the token as non-expired

    @property
    def valid(self):
        return True  # Always consider the credentials as valid


credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)

# Route Vertex SDK traffic through the LiteLLM proxy.
vertexai.init(
    project="adroit-crow-413218",
    location="us-central1",
    api_endpoint=LITELLM_PROXY_BASE,
    credentials=credentials,
    api_transport="rest",
)

model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")

# Fix: `prompt` and `output_file` were referenced below but never defined.
prompt = "An otter swimming in a river"
output_file = "output-image.png"

images = model.generate_images(
    prompt=prompt,
    # Optional parameters
    number_of_images=1,
    language="en",
    # You can't use a seed value and watermark at the same time.
    # add_watermark=False,
    # seed=100,
    aspect_ratio="1:1",
    safety_filter_level="block_some",
    person_generation="allow_adult",
)

images[0].save(location=output_file, include_generation_parameters=False)

# Optional. View the generated image in a notebook.
# images[0].show()

print(f"Created output image using {len(images[0]._image_bytes)} bytes")
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
  -H "Content-Type: application/json" \
@ -256,8 +329,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generat
  -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
```
</TabItem>
</Tabs>
### Count Tokens API
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
import datetime
from typing import List, Optional

import vertexai
from google.auth.credentials import Credentials
from vertexai.generative_models import GenerativeModel

LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"


class CredentialsWrapper(Credentials):
    """Minimal google-auth credentials wrapper around a fixed LiteLLM proxy key."""

    def __init__(self, token=None):
        super().__init__()
        self.token = token
        self.expiry = None  # a future datetime could be set here instead

    def refresh(self, request):
        """Nothing to refresh — the proxy key is static."""

    def apply(self, headers, token=None):
        headers["Authorization"] = f"Bearer {self.token}"

    @property
    def expired(self):
        """The static token never expires."""
        return False

    @property
    def valid(self):
        """The static token is always treated as valid."""
        return True


credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)

# Point the Vertex SDK at the LiteLLM proxy endpoint.
vertexai.init(
    project="adroit-crow-413218",
    location="us-central1",
    api_endpoint=LITELLM_PROXY_BASE,
    credentials=credentials,
    api_transport="rest",
)

model = GenerativeModel("gemini-1.5-flash-001")
prompt = "Why is the sky blue?"

# Count tokens in the prompt before sending it.
response = model.count_tokens(prompt)
print(f"Prompt Token Count: {response.total_tokens}")
print(f"Prompt Character Count: {response.total_billable_characters}")

# Send text to Gemini
response = model.generate_content(prompt)

# Token usage reported by the model for this request.
usage_metadata = response.usage_metadata
print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
print(f"Total Token Count: {usage_metadata.total_token_count}")
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
  -H "Content-Type: application/json" \
@ -265,10 +416,83 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
  -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
```
</TabItem>
</Tabs>
### Tuning API

Create Fine Tuning Job
<Tabs>
<TabItem value="py" label="Vertex Python SDK">
```python
import time

from vertexai.preview.tuning import sft
import vertexai
from google.auth.credentials import Credentials

LITELLM_PROXY_API_KEY = "sk-1234"
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"


class CredentialsWrapper(Credentials):
    """google-auth credentials that present the LiteLLM proxy key as a Bearer token."""

    def __init__(self, token=None):
        super().__init__()
        self.token = token
        self.expiry = None  # or set to a future date if needed

    def refresh(self, request):
        # No-op: the proxy key is static and never refreshed.
        pass

    def apply(self, headers, token=None):
        headers["Authorization"] = f"Bearer {self.token}"

    @property
    def expired(self):
        return False  # Always consider the token as non-expired

    @property
    def valid(self):
        return True  # Always consider the credentials as valid


credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)

# Route Vertex SDK traffic through the LiteLLM proxy.
# Fix: the original example called vertexai.init() a second time with an
# undefined PROJECT_ID, which raised a NameError and would have discarded
# the proxy endpoint configured here.
vertexai.init(
    project="adroit-crow-413218",
    location="us-central1",
    api_endpoint=LITELLM_PROXY_BASE,
    credentials=credentials,
    api_transport="rest",
)

sft_tuning_job = sft.train(
    source_model="gemini-1.0-pro-002",
    train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
)

# Polling for job completion (fix: `time` was used but never imported).
while not sft_tuning_job.has_ended:
    time.sleep(60)
    sft_tuning_job.refresh()

print(sft_tuning_job.tuned_model_name)
print(sft_tuning_job.tuned_model_endpoint_name)
print(sft_tuning_job.experiment)
```
</TabItem>
<TabItem value="curl" label="Curl">
```shell
curl http://localhost:4000/vertex-ai/tuningJobs \
  -H "Content-Type: application/json" \
@ -279,4 +503,8 @@ curl http://localhost:4000/vertex-ai/tuningJobs \
      "training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
    }
  }'
```
</TabItem>
</Tabs>

View file

@ -1694,7 +1694,7 @@ vertexai.init(
api_endpoint=LITELLM_PROXY_BASE, api_endpoint=LITELLM_PROXY_BASE,
credentials = credentials, credentials = credentials,
api_transport="rest", api_transport="rest",
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
) )
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding") model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")

View file

@ -42,7 +42,6 @@ vertexai.init(
api_endpoint=LITELLM_PROXY_BASE, api_endpoint=LITELLM_PROXY_BASE,
credentials=credentials, credentials=credentials,
api_transport="rest", api_transport="rest",
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
) )
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding") model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")