mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
* fix(vertex_llm_base.py): handle credentials passed in as dictionary * fix(router.py): support vertex credentials as json dict * test(test_vertex.py): allows easier testing mock anthropic thinking response for vertex ai * test(vertex_ai_partner_models/): don't remove "@" from model breaks anthropic cost calculation * test: move testing * fix: fix linting error * fix: fix linting error * fix(vertex_ai_partner_models/main.py): split @ for codestral model * test: fix test * fix: fix stripping "@" on mistral models * fix: fix test * test: fix test
217 lines
7 KiB
Python
217 lines
7 KiB
Python
import json
from typing import Any, Coroutine, Dict, Optional, Tuple, Union

import httpx

import litellm
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from litellm.types.llms.openai import Batch, CreateBatchRequest
from litellm.types.llms.vertex_ai import (
    VERTEX_CREDENTIALS_TYPES,
    VertexAIBatchPredictionJob,
)

from .transformation import VertexAIBatchTransformation
|
|
|
|
|
|
class VertexAIBatchPrediction(VertexLLM):
    """Create and retrieve Vertex AI batch prediction jobs.

    Requests are converted from the OpenAI batch format with
    ``VertexAIBatchTransformation``, and Vertex AI job responses are converted
    back into OpenAI ``Batch`` objects.
    """

    def __init__(self, gcs_bucket_name: str, *args, **kwargs):
        """
        Args:
            gcs_bucket_name: Stored as ``self.gcs_bucket_name``. Not read by the
                methods of this class (presumably used by callers; verify).
            *args, **kwargs: Forwarded unchanged to ``VertexLLM.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.gcs_bucket_name = gcs_bucket_name

    def _prepare_batch_request(
        self,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        url_suffix: str = "",
    ) -> Tuple[str, Dict[str, str]]:
        """Resolve auth and build the request URL and headers shared by create/retrieve.

        Args:
            api_base: Optional custom base URL, applied via ``_check_custom_proxy``.
            vertex_credentials: Passed to ``_ensure_access_token``.
            vertex_project: Project id; falls back to the project id returned by
                ``_ensure_access_token`` when falsy.
            vertex_location: Region; defaults to ``"us-central1"`` when falsy.
            url_suffix: Appended verbatim to the batchPredictionJobs URL
                (e.g. ``"/<batch_id>"`` to address a single job).

        Returns:
            ``(url, headers)``: the (possibly proxied) request URL, and headers
            carrying the JSON content type and the bearer access token.
        """
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = (
            self.create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
            )
            + url_suffix
        )

        # NOTE(review): create_vertex_url always returns an "https://..." URL, so
        # ":" is always present and `endpoint` is the text after the LAST ":"
        # (normally "//<host>/<path>"). Behaviour is kept as-is; confirm this is
        # what `_check_custom_proxy` expects when a custom api_base is supplied.
        endpoint = default_api_base.split(":")[-1] if ":" in default_api_base else ""

        _, url = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }
        return url, headers

    @staticmethod
    def _parse_batch_response(response: httpx.Response) -> Batch:
        """Check the HTTP status and convert the Vertex AI job JSON to a ``Batch``.

        Raises:
            Exception: If ``response.status_code`` is not 200.
        """
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        return VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=response.json()
        )

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        """Submit a Vertex AI batch prediction job built from an OpenAI batch request.

        Args:
            _is_async: When True, return an un-awaited coroutine instead of
                performing the request synchronously.
            create_batch_data: OpenAI-style batch creation request.
            api_base: Optional custom base URL, applied via ``_check_custom_proxy``.
            vertex_credentials: Passed to ``_ensure_access_token``.
            vertex_project: Project id; falls back to the resolved project id.
            vertex_location: Region; defaults to ``"us-central1"``.
            timeout: Accepted for interface compatibility; not used here.
            max_retries: Accepted for interface compatibility; not used here.

        Returns:
            The created job as an OpenAI ``Batch``, or a coroutine resolving to
            one when ``_is_async`` is True.

        Raises:
            Exception: If the HTTP response status is not 200.
        """
        url, headers = self._prepare_batch_request(
            api_base=api_base,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
        )

        vertex_batch_request: VertexAIBatchPredictionJob = (
            VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
                request=create_batch_data
            )
        )

        if _is_async is True:
            return self._async_create_batch(
                vertex_batch_request=vertex_batch_request,
                api_base=url,
                headers=headers,
            )

        # The sync client is only needed on the synchronous path.
        sync_handler = _get_httpx_client()
        response = sync_handler.post(
            url=url,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )
        return self._parse_batch_response(response)

    async def _async_create_batch(
        self,
        vertex_batch_request: VertexAIBatchPredictionJob,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
        """POST ``vertex_batch_request`` to ``api_base`` asynchronously and parse the reply.

        Raises:
            Exception: If the HTTP response status is not 200.
        """
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.post(
            url=api_base,
            headers=headers,
            data=json.dumps(vertex_batch_request),
        )
        return self._parse_batch_response(response)

    def create_vertex_url(
        self,
        vertex_location: str,
        vertex_project: str,
    ) -> str:
        """Return the ``batchPredictionJobs`` collection URL for a project and region."""
        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        """Fetch a single Vertex AI batch prediction job and return it as a ``Batch``.

        Args:
            _is_async: When True, return an un-awaited coroutine instead of
                performing the request synchronously.
            batch_id: Job identifier, appended as ``/<batch_id>`` to the
                batchPredictionJobs URL.
            api_base: Optional custom base URL, applied via ``_check_custom_proxy``.
            vertex_credentials: Passed to ``_ensure_access_token``.
            vertex_project: Project id; falls back to the resolved project id.
            vertex_location: Region; defaults to ``"us-central1"``.
            timeout: Accepted for interface compatibility; not used here.
            max_retries: Accepted for interface compatibility; not used here.

        Returns:
            The job as an OpenAI ``Batch``, or a coroutine resolving to one when
            ``_is_async`` is True.

        Raises:
            Exception: If the HTTP response status is not 200.
        """
        url, headers = self._prepare_batch_request(
            api_base=api_base,
            vertex_credentials=vertex_credentials,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            url_suffix=f"/{batch_id}",
        )

        if _is_async is True:
            return self._async_retrieve_batch(
                api_base=url,
                headers=headers,
            )

        # The sync client is only needed on the synchronous path.
        sync_handler = _get_httpx_client()
        response = sync_handler.get(
            url=url,
            headers=headers,
        )
        return self._parse_batch_response(response)

    async def _async_retrieve_batch(
        self,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
        """GET ``api_base`` asynchronously and parse the reply into a ``Batch``.

        Raises:
            Exception: If the HTTP response status is not 200.
        """
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.get(
            url=api_base,
            headers=headers,
        )
        return self._parse_batch_response(response)
|