Mirror of https://github.com/BerriAI/litellm.git
(Feat - Batches API) add support for retrieving vertex api batch jobs (#7661)
* add _async_retrieve_batch
* fix aretrieve_batch
* fix _get_batch_id_from_vertex_ai_batch_response
* fix batches docs
Parent: 2507c275f6
Commit: 13f364682d

5 changed files with 177 additions and 2 deletions
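For context, a minimal usage sketch of what this commit enables, assuming the litellm Batches API surface (litellm.aretrieve_batch) and a batch id obtained from an earlier create-batch call against Vertex AI; the placeholder id below is hypothetical:

import asyncio

import litellm

async def main():
    # Retrieve a previously created Vertex AI batch prediction job.
    # "batch_id" would come from a prior litellm.acreate_batch(...) call;
    # the value shown here is a placeholder.
    retrieved = await litellm.aretrieve_batch(
        batch_id="<your-vertex-batch-id>",
        custom_llm_provider="vertex_ai",
    )
    # The handler returns an OpenAI-style Batch object.
    print(retrieved.id, retrieved.status)

asyncio.run(main())

Under the hood, aretrieve_batch calls retrieve_batch with _is_async=True, so the handler in the diff below returns the coroutine from _async_retrieve_batch for the caller to await.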
@@ -124,3 +124,91 @@ class VertexAIBatchPrediction(VertexLLM):
        """Return the base url for the vertex garden models"""
        # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        vertex_credentials: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[Batch, Coroutine[Any, Any, Batch]]:
        sync_handler = _get_httpx_client()

        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        default_api_base = self.create_vertex_url(
            vertex_location=vertex_location or "us-central1",
            vertex_project=vertex_project or project_id,
        )

        # Append batch_id to the URL
        default_api_base = f"{default_api_base}/{batch_id}"

        if len(default_api_base.split(":")) > 1:
            endpoint = default_api_base.split(":")[-1]
        else:
            endpoint = ""

        _, api_base = self._check_custom_proxy(
            api_base=api_base,
            custom_llm_provider="vertex_ai",
            gemini_api_key=None,
            endpoint=endpoint,
            stream=None,
            auth_header=None,
            url=default_api_base,
        )

        headers = {
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {access_token}",
        }

        if _is_async is True:
            return self._async_retrieve_batch(
                api_base=api_base,
                headers=headers,
            )

        response = sync_handler.get(
            url=api_base,
            headers=headers,
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response

    async def _async_retrieve_batch(
        self,
        api_base: str,
        headers: Dict[str, str],
    ) -> Batch:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.VERTEX_AI,
        )
        response = await client.get(
            url=api_base,
            headers=headers,
        )
        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")

        _json_response = response.json()
        vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
            response=_json_response
        )
        return vertex_batch_response
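Retrieval is a GET on the job resource, i.e. the create URL with the job id appended: https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/batchPredictionJobs/{batch_id}. The raw Vertex response is then normalized by VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response. A rough sketch of that mapping, with an assumed state-to-status table (the commit's actual lookup may differ):

# Illustrative only: the general shape of mapping a Vertex
# batchPredictionJobs.get response onto an OpenAI-style Batch dict.
# The state-to-status table below is an assumption, not the commit's code.
VERTEX_STATE_TO_OPENAI_STATUS = {
    "JOB_STATE_PENDING": "validating",
    "JOB_STATE_RUNNING": "in_progress",
    "JOB_STATE_SUCCEEDED": "completed",
    "JOB_STATE_FAILED": "failed",
    "JOB_STATE_CANCELLED": "cancelled",
}

def sketch_transform(vertex_job: dict) -> dict:
    # A Vertex job "name" has the form:
    # projects/{project}/locations/{location}/batchPredictionJobs/{job_id};
    # the trailing segment is the batch id (cf. the commit's
    # _get_batch_id_from_vertex_ai_batch_response fix).
    return {
        "id": vertex_job["name"].split("/")[-1],
        "object": "batch",
        "status": VERTEX_STATE_TO_OPENAI_STATUS.get(
            vertex_job.get("state", ""), "in_progress"
        ),
    }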