forked from phoenix/litellm-mirror
Merge pull request #5481 from BerriAI/litellm_track_imagen_spend_logs
[Feat-Proxy] track imagen /predict in LiteLLM spend logs
This commit is contained in:
commit
7aaabb51ef
12 changed files with 342 additions and 246 deletions
|
@ -1833,6 +1833,19 @@ response = await litellm.aimage_generation(
|
|||
)
|
||||
```
|
||||
|
||||
### Supported Image Generation Models
|
||||
|
||||
| Model Name | FUsage |
|
||||
|------------------------------|--------------------------------------------------------------|
|
||||
| `imagen-3.0-generate-001` | `litellm.image_generation('vertex_ai/imagen-3.0-generate-001', prompt)` |
|
||||
| `imagen-3.0-fast-generate-001` | `litellm.image_generation('vertex_ai/imagen-3.0-fast-generate-001', prompt)` |
|
||||
| `imagegeneration@006` | `litellm.image_generation('vertex_ai/imagegeneration@006', prompt)` |
|
||||
| `imagegeneration@005` | `litellm.image_generation('vertex_ai/imagegeneration@005', prompt)` |
|
||||
| `imagegeneration@002` | `litellm.image_generation('vertex_ai/imagegeneration@002', prompt)` |
|
||||
|
||||
|
||||
|
||||
|
||||
## **Text to Speech APIs**
|
||||
|
||||
:::info
|
||||
|
|
|
@ -356,6 +356,7 @@ vertex_language_models: List = []
|
|||
vertex_vision_models: List = []
|
||||
vertex_chat_models: List = []
|
||||
vertex_code_chat_models: List = []
|
||||
vertex_ai_image_models: List = []
|
||||
vertex_text_models: List = []
|
||||
vertex_code_text_models: List = []
|
||||
vertex_embedding_models: List = []
|
||||
|
@ -416,6 +417,9 @@ for key, value in model_cost.items():
|
|||
elif value.get("litellm_provider") == "vertex_ai-ai21_models":
|
||||
key = key.replace("vertex_ai/", "")
|
||||
vertex_ai_ai21_models.append(key)
|
||||
elif value.get("litellm_provider") == "vertex_ai-image-models":
|
||||
key = key.replace("vertex_ai/", "")
|
||||
vertex_ai_image_models.append(key)
|
||||
elif value.get("litellm_provider") == "ai21":
|
||||
if value.get("mode") == "chat":
|
||||
ai21_chat_models.append(key)
|
||||
|
|
|
@ -24,7 +24,7 @@ from litellm.llms.anthropic.cost_calculation import (
|
|||
)
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.types.utils import PassthroughCallTypes, Usage
|
||||
from litellm.utils import (
|
||||
CallTypes,
|
||||
CostPerToken,
|
||||
|
@ -625,6 +625,7 @@ def completion_cost(
|
|||
if (
|
||||
call_type == CallTypes.image_generation.value
|
||||
or call_type == CallTypes.aimage_generation.value
|
||||
or call_type == PassthroughCallTypes.passthrough_image_generation.value
|
||||
):
|
||||
### IMAGE GENERATION COST CALCULATION ###
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
|
|
|
@ -13,7 +13,6 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
|||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
from openai.types.image import Image
|
||||
|
||||
import litellm
|
||||
import litellm.litellm_core_utils
|
||||
|
@ -1488,248 +1487,6 @@ class VertexLLM(BaseLLM):
|
|||
encoding=encoding,
|
||||
)
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
model_response: litellm.ImageResponse,
|
||||
model: Optional[
|
||||
str
|
||||
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||
client: Optional[Any] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
aimg_generation=False,
|
||||
):
|
||||
if aimg_generation is True:
|
||||
return self.aimage_generation(
|
||||
prompt=prompt,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
model=model,
|
||||
client=client,
|
||||
optional_params=optional_params,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
model_response=model_response,
|
||||
)
|
||||
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
_httpx_timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = _httpx_timeout
|
||||
else:
|
||||
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
sync_handler = client # type: ignore
|
||||
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||
|
||||
auth_header, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
)
|
||||
optional_params = optional_params or {
|
||||
"sampleCount": 1
|
||||
} # default optional params
|
||||
|
||||
request_data = {
|
||||
"instances": [{"prompt": prompt}],
|
||||
"parameters": optional_params,
|
||||
}
|
||||
|
||||
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
response = sync_handler.post(
|
||||
url=url,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
},
|
||||
data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
"""
|
||||
Vertex AI Image generation response example:
|
||||
{
|
||||
"predictions": [
|
||||
{
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES",
|
||||
"mimeType": "image/png"
|
||||
},
|
||||
{
|
||||
"mimeType": "image/png",
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
_json_response = response.json()
|
||||
if "predictions" not in _json_response:
|
||||
raise litellm.InternalServerError(
|
||||
message=f"image generation response does not contain 'predictions', got {_json_response}",
|
||||
llm_provider="vertex_ai",
|
||||
model=model,
|
||||
)
|
||||
_predictions = _json_response["predictions"]
|
||||
|
||||
_response_data: List[Image] = []
|
||||
for _prediction in _predictions:
|
||||
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
|
||||
image_object = Image(b64_json=_bytes_base64_encoded)
|
||||
_response_data.append(image_object)
|
||||
|
||||
model_response.data = _response_data
|
||||
|
||||
return model_response
|
||||
|
||||
async def aimage_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
model_response: litellm.ImageResponse,
|
||||
model: Optional[
|
||||
str
|
||||
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
_httpx_timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = _httpx_timeout
|
||||
else:
|
||||
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.async_handler = client # type: ignore
|
||||
|
||||
# make POST request to
|
||||
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||
|
||||
"""
|
||||
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-d {
|
||||
"instances": [
|
||||
{
|
||||
"prompt": "a cat"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"sampleCount": 1
|
||||
}
|
||||
} \
|
||||
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
||||
"""
|
||||
auth_header, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
)
|
||||
optional_params = optional_params or {
|
||||
"sampleCount": 1
|
||||
} # default optional params
|
||||
|
||||
request_data = {
|
||||
"instances": [{"prompt": prompt}],
|
||||
"parameters": optional_params,
|
||||
}
|
||||
|
||||
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
response = await self.async_handler.post(
|
||||
url=url,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
},
|
||||
data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
"""
|
||||
Vertex AI Image generation response example:
|
||||
{
|
||||
"predictions": [
|
||||
{
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES",
|
||||
"mimeType": "image/png"
|
||||
},
|
||||
{
|
||||
"mimeType": "image/png",
|
||||
"bytesBase64Encoded": "BASE64_IMG_BYTES"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
_json_response = response.json()
|
||||
|
||||
if "predictions" not in _json_response:
|
||||
raise litellm.InternalServerError(
|
||||
message=f"image generation response does not contain 'predictions', got {_json_response}",
|
||||
llm_provider="vertex_ai",
|
||||
model=model,
|
||||
)
|
||||
|
||||
_predictions = _json_response["predictions"]
|
||||
|
||||
_response_data: List[Image] = []
|
||||
for _prediction in _predictions:
|
||||
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
|
||||
image_object = Image(b64_json=_bytes_base64_encoded)
|
||||
_response_data.append(image_object)
|
||||
|
||||
model_response.data = _response_data
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, streaming_response, sync_stream: bool):
|
||||
|
|
|
@ -0,0 +1,225 @@
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
from openai.types.image import Image
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
|
||||
|
||||
class VertexImageGeneration(VertexLLM):
|
||||
def process_image_generation_response(
|
||||
self,
|
||||
json_response: Dict[str, Any],
|
||||
model_response: litellm.ImageResponse,
|
||||
model: Optional[str] = None,
|
||||
) -> litellm.ImageResponse:
|
||||
if "predictions" not in json_response:
|
||||
raise litellm.InternalServerError(
|
||||
message=f"image generation response does not contain 'predictions', got {json_response}",
|
||||
llm_provider="vertex_ai",
|
||||
model=model,
|
||||
)
|
||||
|
||||
predictions = json_response["predictions"]
|
||||
response_data: List[Image] = []
|
||||
|
||||
for prediction in predictions:
|
||||
bytes_base64_encoded = prediction["bytesBase64Encoded"]
|
||||
image_object = Image(b64_json=bytes_base64_encoded)
|
||||
response_data.append(image_object)
|
||||
|
||||
model_response.data = response_data
|
||||
return model_response
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
model_response: litellm.ImageResponse,
|
||||
model: Optional[
|
||||
str
|
||||
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||
client: Optional[Any] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
aimg_generation=False,
|
||||
):
|
||||
if aimg_generation is True:
|
||||
return self.aimage_generation(
|
||||
prompt=prompt,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
model=model,
|
||||
client=client,
|
||||
optional_params=optional_params,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
model_response=model_response,
|
||||
)
|
||||
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
_httpx_timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = _httpx_timeout
|
||||
else:
|
||||
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
sync_handler = client # type: ignore
|
||||
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||
|
||||
auth_header, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
)
|
||||
optional_params = optional_params or {
|
||||
"sampleCount": 1
|
||||
} # default optional params
|
||||
|
||||
request_data = {
|
||||
"instances": [{"prompt": prompt}],
|
||||
"parameters": optional_params,
|
||||
}
|
||||
|
||||
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
response = sync_handler.post(
|
||||
url=url,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
},
|
||||
data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
return self.process_image_generation_response(
|
||||
json_response, model_response, model
|
||||
)
|
||||
|
||||
async def aimage_generation(
|
||||
self,
|
||||
prompt: str,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
model_response: litellm.ImageResponse,
|
||||
model: Optional[
|
||||
str
|
||||
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
logging_obj=None,
|
||||
):
|
||||
response = None
|
||||
if client is None:
|
||||
_params = {}
|
||||
if timeout is not None:
|
||||
if isinstance(timeout, float) or isinstance(timeout, int):
|
||||
_httpx_timeout = httpx.Timeout(timeout)
|
||||
_params["timeout"] = _httpx_timeout
|
||||
else:
|
||||
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
|
||||
|
||||
self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
|
||||
else:
|
||||
self.async_handler = client # type: ignore
|
||||
|
||||
# make POST request to
|
||||
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
|
||||
|
||||
"""
|
||||
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
|
||||
-H "Content-Type: application/json; charset=utf-8" \
|
||||
-d {
|
||||
"instances": [
|
||||
{
|
||||
"prompt": "a cat"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"sampleCount": 1
|
||||
}
|
||||
} \
|
||||
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
|
||||
"""
|
||||
auth_header, _ = self._ensure_access_token(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
)
|
||||
optional_params = optional_params or {
|
||||
"sampleCount": 1
|
||||
} # default optional params
|
||||
|
||||
request_data = {
|
||||
"instances": [{"prompt": prompt}],
|
||||
"parameters": optional_params,
|
||||
}
|
||||
|
||||
request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
"request_str": request_str,
|
||||
},
|
||||
)
|
||||
|
||||
response = await self.async_handler.post(
|
||||
url=url,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Authorization": f"Bearer {auth_header}",
|
||||
},
|
||||
data=json.dumps(request_data),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(f"Error: {response.status_code} {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
return self.process_image_generation_response(
|
||||
json_response, model_response, model
|
||||
)
|
||||
|
||||
def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
|
||||
if "predictions" in json_response:
|
||||
if "bytesBase64Encoded" in json_response["predictions"][0]:
|
||||
return True
|
||||
return False
|
|
@ -281,3 +281,33 @@ async def async_embedding(
|
|||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
async def transform_vertex_response_to_openai(
|
||||
response: dict, model: str, model_response: litellm.EmbeddingResponse
|
||||
) -> litellm.EmbeddingResponse:
|
||||
|
||||
_predictions = response["predictions"]
|
||||
|
||||
embedding_response = []
|
||||
input_tokens: int = 0
|
||||
for idx, element in enumerate(_predictions):
|
||||
|
||||
embedding = element["embeddings"]
|
||||
embedding_response.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding["values"],
|
||||
}
|
||||
)
|
||||
input_tokens += embedding["statistics"]["token_count"]
|
||||
|
||||
model_response.object = "list"
|
||||
model_response.data = embedding_response
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
|
|
@ -126,6 +126,9 @@ from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gem
|
|||
from .llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import (
|
||||
GoogleBatchEmbeddings,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
|
||||
VertexImageGeneration,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import (
|
||||
VertexMultimodalEmbedding,
|
||||
)
|
||||
|
@ -180,6 +183,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM()
|
|||
bedrock_embedding = BedrockEmbedding()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
vertex_multimodal_embedding = VertexMultimodalEmbedding()
|
||||
vertex_image_generation = VertexImageGeneration()
|
||||
google_batch_embeddings = GoogleBatchEmbeddings()
|
||||
vertex_partner_models_chat_completion = VertexAIPartnerModels()
|
||||
vertex_text_to_speech = VertexTextToSpeechAPI()
|
||||
|
@ -4538,7 +4542,7 @@ def image_generation(
|
|||
or optional_params.pop("vertex_ai_credentials", None)
|
||||
or get_secret("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
model_response = vertex_chat_completion.image_generation(
|
||||
model_response = vertex_image_generation.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
timeout=timeout,
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import re
|
||||
from datetime import datetime
|
||||
from typing import Union
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -103,3 +104,52 @@ class PassThroughEndpointLogging:
|
|||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
elif "predict" in url_route:
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import (
|
||||
VertexImageGeneration,
|
||||
)
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import (
|
||||
transform_vertex_response_to_openai,
|
||||
)
|
||||
from litellm.types.utils import PassthroughCallTypes
|
||||
|
||||
vertex_image_generation_class = VertexImageGeneration()
|
||||
|
||||
model = self.extract_model_from_url(url_route)
|
||||
_json_response = httpx_response.json()
|
||||
|
||||
litellm_prediction_response: Union[
|
||||
litellm.ModelResponse, litellm.EmbeddingResponse, litellm.ImageResponse
|
||||
] = litellm.ModelResponse()
|
||||
if vertex_image_generation_class.is_image_generation_response(
|
||||
_json_response
|
||||
):
|
||||
litellm_prediction_response = (
|
||||
vertex_image_generation_class.process_image_generation_response(
|
||||
_json_response,
|
||||
model_response=litellm.ImageResponse(),
|
||||
model=model,
|
||||
)
|
||||
)
|
||||
|
||||
logging_obj.call_type = (
|
||||
PassthroughCallTypes.passthrough_image_generation.value
|
||||
)
|
||||
else:
|
||||
litellm_prediction_response = await transform_vertex_response_to_openai(
|
||||
response=_json_response,
|
||||
model=model,
|
||||
model_response=litellm.EmbeddingResponse(),
|
||||
)
|
||||
if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
|
||||
litellm_prediction_response.model = model
|
||||
|
||||
logging_obj.model = model
|
||||
logging_obj.model_call_details["model"] = logging_obj.model
|
||||
|
||||
await logging_obj.async_success_handler(
|
||||
result=litellm_prediction_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
|
|
|
@ -844,7 +844,7 @@ async def _PROXY_track_cost_callback(
|
|||
kwargs["stream"] == True and "complete_streaming_response" in kwargs
|
||||
):
|
||||
raise Exception(
|
||||
f"Model not in litellm model cost map. Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
f"Model not in litellm model cost map. Passed model = {kwargs.get('model')} - Add custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"error in tracking cost callback - {traceback.format_exc()}"
|
||||
|
|
|
@ -70,6 +70,12 @@ def test_get_llm_provider_deepseek_custom_api_base():
|
|||
os.environ.pop("DEEPSEEK_API_BASE")
|
||||
|
||||
|
||||
def test_get_llm_provider_vertex_ai_image_models():
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="imagegeneration@006", custom_llm_provider=None
|
||||
)
|
||||
assert custom_llm_provider == "vertex_ai"
|
||||
|
||||
def test_get_llm_provider_ai21_chat():
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="jamba-1.5-large",
|
||||
|
@ -93,3 +99,4 @@ def test_get_llm_provider_ai21_chat_test2():
|
|||
assert custom_llm_provider == "ai21_chat"
|
||||
assert model == "jamba-1.5-large"
|
||||
assert api_base == "https://api.ai21.com/studio/v1"
|
||||
|
||||
|
|
|
@ -119,6 +119,10 @@ class CallTypes(Enum):
|
|||
speech = "speech"
|
||||
|
||||
|
||||
class PassthroughCallTypes(Enum):
|
||||
passthrough_image_generation = "passthrough-image-generation"
|
||||
|
||||
|
||||
class TopLogprob(OpenAIObject):
|
||||
token: str
|
||||
"""The token."""
|
||||
|
|
|
@ -4975,6 +4975,7 @@ def get_llm_provider(
|
|||
or model in litellm.vertex_language_models
|
||||
or model in litellm.vertex_embedding_models
|
||||
or model in litellm.vertex_vision_models
|
||||
or model in litellm.vertex_ai_image_models
|
||||
):
|
||||
custom_llm_provider = "vertex_ai"
|
||||
## ai21
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue