diff --git a/docs/my-website/docs/fine_tuning.md b/docs/my-website/docs/fine_tuning.md index c69f4c1e66..fd3cbc792d 100644 --- a/docs/my-website/docs/fine_tuning.md +++ b/docs/my-website/docs/fine_tuning.md @@ -124,7 +124,7 @@ ft_job = await client.fine_tuning.jobs.create( ``` - + ```shell curl http://localhost:4000/v1/fine_tuning/jobs \ @@ -136,6 +136,28 @@ curl http://localhost:4000/v1/fine_tuning/jobs \ "training_file": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl" }' ``` + + + + +:::info + +Use this to create fine-tuning jobs in [the Vertex AI API format](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#create-tuning) + +::: + +```shell +curl http://localhost:4000/v1/projects/tuningJobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "baseModel": "gemini-1.0-pro-002", + "supervisedTuningSpec" : { + "training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl" + } +}' +``` + diff --git a/docs/my-website/docs/index.md b/docs/my-website/docs/index.md index a560ecf76d..2417b0cff4 100644 --- a/docs/my-website/docs/index.md +++ b/docs/my-website/docs/index.md @@ -17,7 +17,33 @@ You can use litellm through either: 1. [LiteLLM Proxy Server](#openai-proxy) - Server to call 100+ LLMs, load balance, cost tracking across projects 2. [LiteLLM python SDK](#basic-usage) - Python Client to call 100+ LLMs, load balance, cost tracking -## LiteLLM Python SDK +### When to use LiteLLM Proxy Server + +:::tip + +Use LiteLLM Proxy Server if you want a **central service to access multiple LLMs** + +Typically used by Gen AI Enablement / ML Platform Teams + +::: + + - LiteLLM Proxy gives you a unified interface to access multiple LLMs (100+ LLMs) + - Track LLM usage and set up guardrails + - Customize logging, guardrails, and caching per project + +### When to use LiteLLM Python SDK + +:::tip + + Use LiteLLM Python SDK if you want to use LiteLLM in your **Python code** + +Typically used by developers building LLM projects + +::: + + - LiteLLM SDK gives you a unified interface to access multiple LLMs (100+ LLMs) + - Retry/fallback logic across multiple deployments (e.g. Azure/OpenAI) - [Router](https://docs.litellm.ai/docs/routing) + ### Basic usage diff --git a/docs/my-website/docs/proxy/reliability.md b/docs/my-website/docs/proxy/reliability.md index a3f03b3d76..cb6550a478 100644 --- a/docs/my-website/docs/proxy/reliability.md +++ b/docs/my-website/docs/proxy/reliability.md @@ -50,7 +50,7 @@ Detailed information about [routing strategies can be found here](../routing) $ litellm --config /path/to/config.yaml ``` -### Test - Load Balancing +### Test - Simple Call Here requests with model=gpt-3.5-turbo will be routed across multiple instances of azure/gpt-3.5-turbo @@ -138,6 +138,27 @@ print(response) +### Test - Load Balancing + +In this request, the following will occur: +1. A rate limit exception will be raised +2. LiteLLM Proxy will retry the request on the model group (default: 3 retries). + +```bash +curl -X POST 'http://0.0.0.0:4000/chat/completions' \ -H 'Content-Type: application/json' \ -H 'Authorization: Bearer sk-1234' \ -d '{ "model": "gpt-3.5-turbo", "messages": [ {"role": "user", "content": "Hi there!"} ], "mock_testing_rate_limit_error": true }' ``` + +[**See Code**](https://github.com/BerriAI/litellm/blob/6b8806b45f970cb2446654d2c379f8dcaa93ce3c/litellm/router.py#L2535) + ### Test - Client Side Fallbacks In this request the following will occur: 1.
The request to `model="zephyr-beta"` will fail diff --git a/docs/my-website/docs/proxy/user_keys.md b/docs/my-website/docs/proxy/user_keys.md index 75e547d17e..79d019a20c 100644 --- a/docs/my-website/docs/proxy/user_keys.md +++ b/docs/my-website/docs/proxy/user_keys.md @@ -23,6 +23,9 @@ LiteLLM Proxy is **Azure OpenAI-compatible**: LiteLLM Proxy is **Anthropic-compatible**: * /messages +LiteLLM Proxy is **Vertex AI compatible**: +- [Supports ALL Vertex Endpoints](../vertex_ai) + This doc covers: * /chat/completion diff --git a/docs/my-website/docs/vertex_ai.md b/docs/my-website/docs/vertex_ai.md new file mode 100644 index 0000000000..d9c8616a0b --- /dev/null +++ b/docs/my-website/docs/vertex_ai.md @@ -0,0 +1,93 @@ +# [BETA] Vertex AI Endpoints + +## Supported API Endpoints + +- Gemini API +- Embeddings API +- Imagen API +- Code Completion API +- Batch prediction API +- Tuning API +- CountTokens API + +## Quick Start Usage + +#### 1. Set `default_vertex_config` on your `config.yaml` + + +Add the following credentials to your litellm config.yaml to use the Vertex AI endpoints. + +```yaml +default_vertex_config: + vertex_project: "adroit-crow-413218" + vertex_location: "us-central1" + vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" # Add path to service account.json +``` + +#### 2. Start litellm proxy + +```shell +litellm --config /path/to/config.yaml +``` + +#### 3. Test it + +```shell +curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:countTokens \ +-H "Content-Type: application/json" \ +-H "Authorization: Bearer sk-1234" \ +-d '{"instances":[{"content": "gm"}]}' +``` +## Usage Examples + +### Gemini API (Generate Content) + +```shell +curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}' +``` + +### Embeddings API + +```shell +curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"instances":[{"content": "gm"}]}' +``` + +### Imagen API + +```shell +curl http://localhost:4000/vertex-ai/publishers/google/models/imagen-3.0-generate-001:predict \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"instances":[{"prompt": "make an otter"}], "parameters": {"sampleCount": 1}}' +``` + +### Count Tokens API + +```shell +curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}' +``` + +### Tuning API + +Create Fine Tuning Job + +```shell +curl http://localhost:4000/vertex-ai/tuningJobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "baseModel": "gemini-1.0-pro-002", + "supervisedTuningSpec" : { + "training_dataset_uri": "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl" + } +}' +``` \ No newline at end of file diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 27084f3b45..0305a7d81b 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -24,7 +24,7 @@ const sidebars = { link: { type: "generated-index", title: "💥 LiteLLM Proxy Server", - description: `Proxy Server to call 100+ LLMs in a 
unified interface & track spend, set budgets per virtual key/user`, + description: `OpenAI Proxy Server to call 100+ LLMs in a unified interface & track spend, set budgets per virtual key/user`, slug: "/simple_proxy", }, items: [ @@ -178,7 +178,7 @@ const sidebars = { }, { type: "category", - label: "Embedding(), Image Generation(), Assistants(), Moderation(), Audio Transcriptions(), TTS(), Batches(), Fine-Tuning()", + label: "Supported Endpoints - /images, /audio/speech, /assistants etc", items: [ "embedding/supported_embedding", "embedding/async_embedding", @@ -189,7 +189,8 @@ const sidebars = { "assistants", "batches", "fine_tuning", - "anthropic_completion" + "anthropic_completion", + "vertex_ai" ], }, { diff --git a/litellm/llms/fine_tuning_apis/vertex_ai.py b/litellm/llms/fine_tuning_apis/vertex_ai.py index f370652d26..5f96f04831 100644 --- a/litellm/llms/fine_tuning_apis/vertex_ai.py +++ b/litellm/llms/fine_tuning_apis/vertex_ai.py @@ -240,3 +240,59 @@ class VertexFineTuningAPI(VertexLLM): vertex_response ) return open_ai_response + + async def pass_through_vertex_ai_POST_request( + self, + request_data: dict, + vertex_project: str, + vertex_location: str, + vertex_credentials: str, + request_route: str, + ): + auth_header, _ = self._get_token_and_url( + model="", + gemini_api_key=None, + vertex_credentials=vertex_credentials, + vertex_project=vertex_project, + vertex_location=vertex_location, + stream=False, + custom_llm_provider="vertex_ai_beta", + api_base="", + ) + + headers = { + "Authorization": f"Bearer {auth_header}", + "Content-Type": "application/json", + } + + url = None + if request_route == "/tuningJobs": + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs" + elif "/tuningJobs/" in request_route and "cancel" in request_route: + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs{request_route}" + elif "generateContent" in request_route: + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}" + elif "predict" in request_route: + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}" + elif "/batchPredictionJobs" in request_route: + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}" + elif "countTokens" in request_route: + url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}" + else: + raise ValueError(f"Unsupported Vertex AI request route: {request_route}") + if self.async_handler is None: + raise ValueError("VertexAI Fine Tuning - async_handler is not initialized") + + response = await self.async_handler.post( + headers=headers, + url=url, + json=request_data, # type: ignore + ) + + if response.status_code != 200: + raise Exception( + f"Error creating fine tuning job. Status code: {response.status_code}. 
Response: {response.text}" + ) + + response_json = response.json() + return response_json diff --git a/litellm/main.py b/litellm/main.py index f3e006feb6..f0eb00ecdd 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -5101,23 +5101,27 @@ def stream_chunk_builder( combined_content = "" combined_arguments = "" - if ( - "tool_calls" in chunks[0]["choices"][0]["delta"] - and chunks[0]["choices"][0]["delta"]["tool_calls"] is not None - ): + tool_call_chunks = [ + chunk + for chunk in chunks + if "tool_calls" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["tool_calls"] is not None + ] + + if len(tool_call_chunks) > 0: argument_list = [] - delta = chunks[0]["choices"][0]["delta"] + delta = tool_call_chunks[0]["choices"][0]["delta"] message = response["choices"][0]["message"] message["tool_calls"] = [] id = None name = None type = None tool_calls_list = [] - prev_index = 0 + prev_index = None prev_id = None curr_id = None curr_index = 0 - for chunk in chunks: + for chunk in tool_call_chunks: choices = chunk["choices"] for choice in choices: delta = choice.get("delta", {}) @@ -5139,6 +5143,8 @@ def stream_chunk_builder( name = tool_calls[0].function.name if tool_calls[0].type: type = tool_calls[0].type + if prev_index is None: + prev_index = curr_index if curr_index != prev_index: # new tool call combined_arguments = "".join(argument_list) tool_calls_list.append( @@ -5157,18 +5163,24 @@ def stream_chunk_builder( tool_calls_list.append( { "id": id, + "index": curr_index, "function": {"arguments": combined_arguments, "name": name}, "type": type, } ) response["choices"][0]["message"]["content"] = None response["choices"][0]["message"]["tool_calls"] = tool_calls_list - elif ( - "function_call" in chunks[0]["choices"][0]["delta"] - and chunks[0]["choices"][0]["delta"]["function_call"] is not None - ): + + function_call_chunks = [ + chunk + for chunk in chunks + if "function_call" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["function_call"] is not None + ] + + if len(function_call_chunks) > 0: argument_list = [] - delta = chunks[0]["choices"][0]["delta"] + delta = function_call_chunks[0]["choices"][0]["delta"] function_call = delta.get("function_call", "") function_call_name = function_call.name @@ -5176,7 +5188,7 @@ def stream_chunk_builder( message["function_call"] = {} message["function_call"]["name"] = function_call_name - for chunk in chunks: + for chunk in function_call_chunks: choices = chunk["choices"] for choice in choices: delta = choice.get("delta", {}) @@ -5193,7 +5205,15 @@ def stream_chunk_builder( response["choices"][0]["message"]["function_call"][ "arguments" ] = combined_arguments - else: + + content_chunks = [ + chunk + for chunk in chunks + if "content" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["content"] is not None + ] + + if len(content_chunks) > 0: for chunk in chunks: choices = chunk["choices"] for choice in choices: @@ -5209,12 +5229,12 @@ def stream_chunk_builder( # Update the "content" field within the response dictionary response["choices"][0]["message"]["content"] = combined_content + completion_output = "" if len(combined_content) > 0: - completion_output = combined_content - elif len(combined_arguments) > 0: - completion_output = combined_arguments - else: - completion_output = "" + completion_output += combined_content + if len(combined_arguments) > 0: + completion_output += combined_arguments + # # Update usage information if needed prompt_tokens = 0 completion_tokens = 0 diff --git 
a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 5f7d933a93..47b93ccd2f 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,7 +1,10 @@ model_list: - - model_name: "claude-3-5-sonnet-20240620" + - model_name: "gpt-4" litellm_params: - model: "claude-3-5-sonnet-20240620" - -# litellm_settings: -# failure_callback: ["langfuse"] + model: "gpt-4" + - model_name: "gpt-4" + litellm_params: + model: "gpt-4o" + - model_name: "gpt-4o-mini" + litellm_params: + model: "gpt-4o-mini" \ No newline at end of file diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index aa2bfc5252..0750a39376 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -48,6 +48,11 @@ files_settings: - custom_llm_provider: openai api_key: os.environ/OPENAI_API_KEY +default_vertex_config: + vertex_project: "adroit-crow-413218" + vertex_location: "us-central1" + vertex_credentials: "/Users/ishaanjaffer/Downloads/adroit-crow-413218-a956eef1a2a8.json" + general_settings: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 538feac49e..a9b49138b9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -213,6 +213,8 @@ from litellm.proxy.utils import ( send_email, update_spend, ) +from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router +from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config from litellm.router import ( AssistantsTypedDict, Deployment, @@ -1818,6 +1820,10 @@ class ProxyConfig: files_config = config.get("files_settings", None) set_files_config(config=files_config) + ## default config for vertex ai routes + default_vertex_config = config.get("default_vertex_config", None) + set_default_vertex_config(config=default_vertex_config) + ## ROUTER SETTINGS (e.g. routing_strategy, ...) 
router_settings = config.get("router_settings", None) if router_settings and isinstance(router_settings, dict): @@ -9698,6 +9704,7 @@ def cleanup_router_config_variables(): app.include_router(router) app.include_router(fine_tuning_router) +app.include_router(vertex_router) app.include_router(health_router) app.include_router(key_management_router) app.include_router(internal_user_router) diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py new file mode 100644 index 0000000000..b8c04583c3 --- /dev/null +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -0,0 +1,305 @@ +import ast +import asyncio +import traceback +from datetime import datetime, timedelta, timezone +from typing import List, Optional + +import fastapi +import httpx +from fastapi import ( + APIRouter, + Depends, + File, + Form, + Header, + HTTPException, + Request, + Response, + UploadFile, + status, +) + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.batches.main import FileObject +from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance +from litellm.proxy._types import * +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + +router = APIRouter() +default_vertex_config = None + + +def set_default_vertex_config(config): + global default_vertex_config + if config is None: + return + + if not isinstance(config, dict): + raise ValueError("invalid config, vertex default config must be a dictionary") + + if isinstance(config, dict): + for key, value in config.items(): + if isinstance(value, str) and value.startswith("os.environ/"): + config[key] = litellm.get_secret(value) + + default_vertex_config = config + + +def exception_handler(e: Exception): + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.v1/projects/tuningJobs(): Exception occurred - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + return ProxyException( + message=getattr(e, "message", str(e.detail)), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + else: + error_msg = f"{str(e)}" + return ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + +async def execute_post_vertex_ai_request( + request: Request, + route: str, +): + from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance + + if default_vertex_config is None: + raise ValueError( + "Vertex credentials not added on litellm proxy, please add `default_vertex_config` on your config.yaml" + ) + vertex_project = default_vertex_config.get("vertex_project", None) + vertex_location = default_vertex_config.get("vertex_location", None) + vertex_credentials = default_vertex_config.get("vertex_credentials", None) + + request_data_json = {} + body = await request.body() + body_str = body.decode() + if len(body_str) > 0: + try: + request_data_json = ast.literal_eval(body_str) + except: + request_data_json = json.loads(body_str) + + verbose_proxy_logger.debug( + "Request received by LiteLLM:\n{}".format( + json.dumps(request_data_json, indent=4) + ), + ) + + response = ( + await vertex_fine_tuning_apis_instance.pass_through_vertex_ai_POST_request( + request_data=request_data_json, + vertex_project=vertex_project, + vertex_location=vertex_location, + 
vertex_credentials=vertex_credentials, + request_route=route, + ) + ) + + return response + + +@router.post( + "/vertex-ai/publishers/google/models/{model_id:path}:generateContent", + dependencies=[Depends(user_api_key_auth)], + tags=["Vertex AI endpoints"], +) +async def vertex_generate_content( + request: Request, + fastapi_response: Response, + model_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + this is a pass through endpoint for the Vertex AI API. /generateContent endpoint + + Example Curl: + ``` + curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:generateContent \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}' + ``` + + Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#rest + it uses the vertex ai credentials on the proxy and forwards to vertex ai api + """ + try: + response = await execute_post_vertex_ai_request( + request=request, + route=f"/publishers/google/models/{model_id}:generateContent", + ) + return response + except Exception as e: + raise exception_handler(e) from e + + +@router.post( + "/vertex-ai/publishers/google/models/{model_id:path}:predict", + dependencies=[Depends(user_api_key_auth)], + tags=["Vertex AI endpoints"], +) +async def vertex_predict_endpoint( + request: Request, + fastapi_response: Response, + model_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + this is a pass through endpoint for the Vertex AI API. /predict endpoint + Use this for: + - Embeddings API - Text Embedding, Multi Modal Embedding + - Imagen API + - Code Completion API + + Example Curl: + ``` + curl http://localhost:4000/vertex-ai/publishers/google/models/textembedding-gecko@001:predict \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"instances":[{"content": "gm"}]}' + ``` + + Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#generative-ai-get-text-embedding-drest + it uses the vertex ai credentials on the proxy and forwards to vertex ai api + """ + try: + response = await execute_post_vertex_ai_request( + request=request, + route=f"/publishers/google/models/{model_id}:predict", + ) + return response + except Exception as e: + raise exception_handler(e) from e + + +@router.post( + "/vertex-ai/publishers/google/models/{model_id:path}:countTokens", + dependencies=[Depends(user_api_key_auth)], + tags=["Vertex AI endpoints"], +) +async def vertex_countTokens_endpoint( + request: Request, + fastapi_response: Response, + model_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + this is a pass through endpoint for the Vertex AI API. 
/countTokens endpoint + https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/count-tokens#curl + + + Example Curl: + ``` + curl http://localhost:4000/vertex-ai/publishers/google/models/gemini-1.5-flash-001:countTokens \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"contents":[{"role": "user", "parts":[{"text": "hi"}]}]}' + ``` + + it uses the vertex ai credentials on the proxy and forwards to vertex ai api + """ + try: + response = await execute_post_vertex_ai_request( + request=request, + route=f"/publishers/google/models/{model_id}:countTokens", + ) + return response + except Exception as e: + raise exception_handler(e) from e + + +@router.post( + "/vertex-ai/batchPredictionJobs", + dependencies=[Depends(user_api_key_auth)], + tags=["Vertex AI endpoints"], +) +async def vertex_create_batch_prediction_job( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + this is a pass through endpoint for the Vertex AI API. /batchPredictionJobs endpoint + + Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/batch-prediction-api#syntax + + it uses the vertex ai credentials on the proxy and forwards to vertex ai api + """ + try: + response = await execute_post_vertex_ai_request( + request=request, + route="/batchPredictionJobs", + ) + return response + except Exception as e: + raise exception_handler(e) from e + + +@router.post( + "/vertex-ai/tuningJobs", + dependencies=[Depends(user_api_key_auth)], + tags=["Vertex AI endpoints"], +) +async def vertex_create_fine_tuning_job( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + this is a pass through endpoint for the Vertex AI API. /tuningJobs endpoint + + Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning + + it uses the vertex ai credentials on the proxy and forwards to vertex ai api + """ + try: + response = await execute_post_vertex_ai_request( + request=request, + route="/tuningJobs", + ) + return response + except Exception as e: + raise exception_handler(e) from e + + +@router.post( + "/vertex-ai/tuningJobs/{job_id:path}:cancel", + dependencies=[Depends(user_api_key_auth)], + tags=["Vertex AI endpoints"], +) +async def vertex_cancel_fine_tuning_job( + request: Request, + job_id: str, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + this is a pass through endpoint for the Vertex AI API. tuningJobs/{job_id:path}:cancel + + Vertex API Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning#cancel_a_tuning_job + + it uses the vertex ai credentials on the proxy and forwards to vertex ai api + """ + try: + + response = await execute_post_vertex_ai_request( + request=request, + route=f"/tuningJobs/{job_id}:cancel", + ) + return response + except Exception as e: + raise exception_handler(e) from e diff --git a/litellm/router.py b/litellm/router.py index 108ca706c5..e31de5332e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -2468,6 +2468,8 @@ class Router: verbose_router_logger.info( f"No fallback model group found for original model_group={model_group}. Fallbacks={fallbacks}" ) + if hasattr(original_exception, "message"): + original_exception.message += f"No fallback model group found for original model_group={model_group}. 
Fallbacks={fallbacks}" raise original_exception for mg in fallback_model_group: """ @@ -2492,14 +2494,20 @@ class Router: return response except Exception as e: raise e - except Exception as e: - verbose_router_logger.error(f"An exception occurred - {str(e)}") - verbose_router_logger.debug(traceback.format_exc()) + except Exception as new_exception: + verbose_router_logger.error( + "litellm.router.py::async_function_with_fallbacks() - Error occurred while trying to do fallbacks - {}\n{}\n\nDebug Information:\nCooldown Deployments={}".format( + str(new_exception), + traceback.format_exc(), + await self._async_get_cooldown_deployments_with_debug_info(), + ) + ) if hasattr(original_exception, "message"): # add the available fallbacks to the exception original_exception.message += "\nReceived Model Group={}\nAvailable Model Group Fallbacks={}".format( - model_group, fallback_model_group + model_group, + fallback_model_group, ) raise original_exception @@ -2508,6 +2516,9 @@ class Router: f"Inside async function with retries: args - {args}; kwargs - {kwargs}" ) original_function = kwargs.pop("original_function") + mock_testing_rate_limit_error = kwargs.pop( + "mock_testing_rate_limit_error", None + ) fallbacks = kwargs.pop("fallbacks", self.fallbacks) context_window_fallbacks = kwargs.pop( "context_window_fallbacks", self.context_window_fallbacks @@ -2515,13 +2526,25 @@ class Router: content_policy_fallbacks = kwargs.pop( "content_policy_fallbacks", self.content_policy_fallbacks ) - + model_group = kwargs.get("model") num_retries = kwargs.pop("num_retries") verbose_router_logger.debug( f"async function w/ retries: original_function - {original_function}, num_retries - {num_retries}" ) try: + if ( + mock_testing_rate_limit_error is not None + and mock_testing_rate_limit_error is True + ): + verbose_router_logger.info( + "litellm.router.py::async_function_with_retries() - mock_testing_rate_limit_error=True. Raising litellm.RateLimitError." 
+ ) + raise litellm.RateLimitError( + model=model_group, + llm_provider="", + message=f"This is a mock exception for model={model_group}, to trigger a rate limit error.", + ) # if the function call is successful, no exception will be raised and we'll break out of the loop response = await original_function(*args, **kwargs) return response diff --git a/litellm/tests/stream_chunk_testdata.py b/litellm/tests/stream_chunk_testdata.py new file mode 100644 index 0000000000..6be9d1ebdf --- /dev/null +++ b/litellm/tests/stream_chunk_testdata.py @@ -0,0 +1,543 @@ +from litellm.types.utils import ( + ChatCompletionDeltaToolCall, + Delta, + Function, + ModelResponse, + StreamingChoices, +) + +chunks = [ + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="To answer", + role="assistant", + function_call=None, + tool_calls=None, + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" your", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" question about", + role=None, + function_call=None, + tool_calls=None, + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" how", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" many rows are in the ", + role=None, + function_call=None, + tool_calls=None, + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="'users' table, I", + role=None, + function_call=None, + tool_calls=None, + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="'ll", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + 
content=" need to", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" run", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" a SQL query.", + role=None, + function_call=None, + tool_calls=None, + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" Let", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" me", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" ", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="do that for", + role=None, + function_call=None, + tool_calls=None, + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=" you.", role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id="toolu_01H3AjkLpRtGQrof13CBnWfK", + function=Function(arguments="", name="sql_query"), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + 
content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments="", name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656356, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments='{"', name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments='query": ', name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments='"SELECT C', name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments="OUNT(*", name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments=") ", name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments="FROM use", name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content="", + role=None, + 
function_call=None, + tool_calls=[ + ChatCompletionDeltaToolCall( + id=None, + function=Function(arguments='rs;"}', name=None), + type="function", + index=1, + ) + ], + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), + ModelResponse( + id="chatcmpl-634a6ad3-483a-44a1-8cdd-3befbeb4ac2f", + choices=[ + StreamingChoices( + finish_reason="tool_calls", + index=0, + delta=Delta( + content=None, role=None, function_call=None, tool_calls=None + ), + logprobs=None, + ) + ], + created=1722656357, + model="claude-3-5-sonnet-20240620", + object="chat.completion.chunk", + system_fingerprint=None, + ), +] diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index c26035ad0a..eec163f26a 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt -# litellm.num_retries = 3 +# litellm.num_retries=3 litellm.cache = None litellm.success_callback = [] user_message = "Write a short poem about the sky" diff --git a/litellm/tests/test_fine_tuning_api.py b/litellm/tests/test_fine_tuning_api.py index 20a58c4d00..412ffb497c 100644 --- a/litellm/tests/test_fine_tuning_api.py +++ b/litellm/tests/test_fine_tuning_api.py @@ -80,7 +80,10 @@ def test_create_fine_tune_job(): except openai.RateLimitError: pass except Exception as e: - pytest.fail(f"Error occurred: {e}") + if "Job has already completed" in str(e): + return + else: + pytest.fail(f"Error occurred: {e}") @pytest.mark.asyncio @@ -135,7 +138,7 @@ async def test_create_fine_tune_jobs_async(): pass except Exception as e: if "Job has already completed" in str(e): - pass + return else: pytest.fail(f"Error occurred: {e}") pass diff --git a/litellm/tests/test_image_generation.py b/litellm/tests/test_image_generation.py index fcbaa74579..56271e37c2 100644 --- a/litellm/tests/test_image_generation.py +++ b/litellm/tests/test_image_generation.py @@ -7,6 +7,7 @@ import sys import traceback from dotenv import load_dotenv +from openai.types.image import Image logging.basicConfig(level=logging.DEBUG) load_dotenv() diff --git a/litellm/tests/test_stream_chunk_builder.py b/litellm/tests/test_stream_chunk_builder.py index 342b070ae7..78d2617f1e 100644 --- a/litellm/tests/test_stream_chunk_builder.py +++ b/litellm/tests/test_stream_chunk_builder.py @@ -18,6 +18,8 @@ from openai import OpenAI import litellm from litellm import completion, stream_chunk_builder +import litellm.tests.stream_chunk_testdata + dotenv.load_dotenv() user_message = "What is the current weather in Boston?" @@ -196,3 +198,24 @@ def test_stream_chunk_builder_litellm_usage_chunks(): # assert prompt tokens are the same assert gemini_pt == stream_rebuilt_pt + + +def test_stream_chunk_builder_litellm_mixed_calls(): + response = stream_chunk_builder(litellm.tests.stream_chunk_testdata.chunks) + assert ( + response.choices[0].message.content + == "To answer your question about how many rows are in the 'users' table, I'll need to run a SQL query. Let me do that for you." 
+ ) + + print(response.choices[0].message.tool_calls[0].to_dict()) + + assert len(response.choices[0].message.tool_calls) == 1 + assert response.choices[0].message.tool_calls[0].to_dict() == { + "index": 1, + "function": { + "arguments": '{"query": "SELECT COUNT(*) FROM users;"}', + "name": "sql_query", + }, + "id": "toolu_01H3AjkLpRtGQrof13CBnWfK", + "type": "function", + } diff --git a/pyproject.toml b/pyproject.toml index 6293b77fb0..a803c8e0fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.42.11" +version = "1.42.12" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.42.11" +version = "1.42.12" version_files = [ "pyproject.toml:^version" ]