mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
use litellm proxy with vertex ai sdk
This commit is contained in:
parent
f310ce541f
commit
0ea1f367d7
3 changed files with 233 additions and 6 deletions
|
@ -84,7 +84,6 @@ vertexai.init(
|
||||||
api_endpoint=LITELLM_PROXY_BASE,
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
api_transport="rest",
|
api_transport="rest",
|
||||||
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model = GenerativeModel("gemini-1.5-flash-001")
|
model = GenerativeModel("gemini-1.5-flash-001")
|
||||||
|
@ -143,7 +142,7 @@ vertexai.init(
|
||||||
api_endpoint=LITELLM_PROXY_BASE,
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
api_transport="rest",
|
api_transport="rest",
|
||||||
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model = GenerativeModel("gemini-1.5-flash-001")
|
model = GenerativeModel("gemini-1.5-flash-001")
|
||||||
|
@ -216,7 +215,7 @@ vertexai.init(
|
||||||
api_endpoint=LITELLM_PROXY_BASE,
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
api_transport="rest",
|
api_transport="rest",
|
||||||
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
|
)
|
||||||
|
|
||||||
|
|
||||||
def embed_text(
|
def embed_text(
|
||||||
|
@ -249,6 +248,80 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-geck
|
||||||
|
|
||||||
### Imagen API
|
### Imagen API
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import List, Optional
|
||||||
|
from vertexai.preview.vision_models import ImageGenerationModel
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = ImageGenerationModel.from_pretrained("imagen-3.0-generate-001")
|
||||||
|
|
||||||
|
images = model.generate_images(
|
||||||
|
prompt=prompt,  # NOTE: define `prompt` (a text description of the image) before this call
|
||||||
|
# Optional parameters
|
||||||
|
number_of_images=1,
|
||||||
|
language="en",
|
||||||
|
# You can't use a seed value and watermark at the same time.
|
||||||
|
# add_watermark=False,
|
||||||
|
# seed=100,
|
||||||
|
aspect_ratio="1:1",
|
||||||
|
safety_filter_level="block_some",
|
||||||
|
person_generation="allow_adult",
|
||||||
|
)
|
||||||
|
|
||||||
|
images[0].save(location=output_file, include_generation_parameters=False)  # NOTE: define `output_file` (destination path) before this call
|
||||||
|
|
||||||
|
# Optional. View the generated image in a notebook.
|
||||||
|
# images[0].show()
|
||||||
|
|
||||||
|
print(f"Created output image using {len(images[0]._image_bytes)} bytes")
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
|
curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -256,8 +329,86 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generat
|
||||||
-d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
|
-d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### Count Tokens API
|
### Count Tokens API
|
||||||
|
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import List, Optional
|
||||||
|
from vertexai.generative_models import GenerativeModel
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
model = GenerativeModel("gemini-1.5-flash-001")
|
||||||
|
|
||||||
|
prompt = "Why is the sky blue?"
|
||||||
|
|
||||||
|
# Prompt tokens count
|
||||||
|
response = model.count_tokens(prompt)
|
||||||
|
print(f"Prompt Token Count: {response.total_tokens}")
|
||||||
|
print(f"Prompt Character Count: {response.total_billable_characters}")
|
||||||
|
|
||||||
|
# Send text to Gemini
|
||||||
|
response = model.generate_content(prompt)
|
||||||
|
|
||||||
|
# Response tokens count
|
||||||
|
usage_metadata = response.usage_metadata
|
||||||
|
print(f"Prompt Token Count: {usage_metadata.prompt_token_count}")
|
||||||
|
print(f"Candidates Token Count: {usage_metadata.candidates_token_count}")
|
||||||
|
print(f"Total Token Count: {usage_metadata.total_token_count}")
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -265,10 +416,83 @@ curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-0
|
||||||
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
-d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
</Tabs>
|
||||||
|
|
||||||
### Tuning API
|
### Tuning API
|
||||||
|
|
||||||
Create Fine Tuning Job
|
Create Fine Tuning Job
|
||||||
|
|
||||||
|
<Tabs>
|
||||||
|
|
||||||
|
<TabItem value="py" label="Vertex Python SDK">
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import List, Optional
|
||||||
|
from vertexai.preview.tuning import sft
|
||||||
|
import vertexai
|
||||||
|
from google.auth.credentials import Credentials
|
||||||
|
|
||||||
|
LITELLM_PROXY_API_KEY = "sk-1234"
|
||||||
|
LITELLM_PROXY_BASE = "http://0.0.0.0:4000/vertex-ai"
|
||||||
|
|
||||||
|
import datetime
import time  # required by the time.sleep() call in the job-polling loop
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialsWrapper(Credentials):
|
||||||
|
def __init__(self, token=None):
|
||||||
|
super().__init__()
|
||||||
|
self.token = token
|
||||||
|
self.expiry = None # or set to a future date if needed
|
||||||
|
|
||||||
|
def refresh(self, request):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def apply(self, headers, token=None):
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expired(self):
|
||||||
|
return False # Always consider the token as non-expired
|
||||||
|
|
||||||
|
@property
|
||||||
|
def valid(self):
|
||||||
|
return True # Always consider the credentials as valid
|
||||||
|
|
||||||
|
|
||||||
|
credentials = CredentialsWrapper(token=LITELLM_PROXY_API_KEY)
|
||||||
|
|
||||||
|
vertexai.init(
|
||||||
|
project="adroit-crow-413218",
|
||||||
|
location="us-central1",
|
||||||
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
|
credentials=credentials,
|
||||||
|
api_transport="rest",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(developer): Define PROJECT_ID (your Google Cloud project ID) before running; it is not set anywhere in this snippet
|
||||||
|
vertexai.init(project=PROJECT_ID, location="us-central1")
|
||||||
|
|
||||||
|
sft_tuning_job = sft.train(
|
||||||
|
source_model="gemini-1.0-pro-002",
|
||||||
|
train_dataset="gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Polling for job completion
|
||||||
|
while not sft_tuning_job.has_ended:
|
||||||
|
time.sleep(60)
|
||||||
|
sft_tuning_job.refresh()
|
||||||
|
|
||||||
|
print(sft_tuning_job.tuned_model_name)
|
||||||
|
print(sft_tuning_job.tuned_model_endpoint_name)
|
||||||
|
print(sft_tuning_job.experiment)
|
||||||
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
<TabItem value="curl" label="Curl">
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl http://localhost:4000/vertex-ai/tuningJobs \
|
curl http://localhost:4000/vertex-ai/tuningJobs \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
|
@ -279,4 +503,8 @@ curl http://localhost:4000/vertex-ai/tuningJobs \
|
||||||
"training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
|
"training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
|
||||||
}
|
}
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
</TabItem>
|
||||||
|
|
||||||
|
</Tabs>
|
||||||
|
|
|
@ -1694,7 +1694,7 @@ vertexai.init(
|
||||||
api_endpoint=LITELLM_PROXY_BASE,
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
credentials = credentials,
|
credentials = credentials,
|
||||||
api_transport="rest",
|
api_transport="rest",
|
||||||
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
|
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
|
||||||
|
|
|
@ -42,7 +42,6 @@ vertexai.init(
|
||||||
api_endpoint=LITELLM_PROXY_BASE,
|
api_endpoint=LITELLM_PROXY_BASE,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
api_transport="rest",
|
api_transport="rest",
|
||||||
request_metadata=[("Authorization", f"Bearer {LITELLM_PROXY_API_KEY}")],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
|
model = MultiModalEmbeddingModel.from_pretrained("multimodalembedding")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue