Merge pull request #5457 from BerriAI/litellm_track_spend_logs_for_vertex_pass_through_endpoints

[Feat-Proxy] track spend logs for vertex pass through endpoints
Ishaan Jaff 2024-08-31 16:30:15 -07:00 committed by GitHub
commit 56f10224df
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 384 additions and 10 deletions

View file

@@ -251,7 +251,7 @@ jobs:
          command: |
            pwd
            ls
-           python -m pytest -vv tests/ -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests
+           python -m pytest -s -vv tests/ -x --junitxml=test-results/junit.xml --durations=5 --ignore=tests/otel_tests --ignore=tests/pass_through_tests
          no_output_timeout: 120m
      # Store test results
@@ -363,6 +363,100 @@ jobs:
      - store_test_results:
          path: test-results
  proxy_pass_through_endpoint_tests:
    machine:
      image: ubuntu-2204:2023.10.1
    resource_class: xlarge
    working_directory: ~/project
    steps:
      - checkout
      - run:
          name: Install Docker CLI (In case it's not already installed)
          command: |
            sudo apt-get update
            sudo apt-get install -y docker-ce docker-ce-cli containerd.io
      - run:
          name: Install Python 3.9
          command: |
            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh --output miniconda.sh
            bash miniconda.sh -b -p $HOME/miniconda
            export PATH="$HOME/miniconda/bin:$PATH"
            conda init bash
            source ~/.bashrc
            conda create -n myenv python=3.9 -y
            conda activate myenv
            python --version
      - run:
          name: Install Dependencies
          command: |
            pip install "pytest==7.3.1"
            pip install "pytest-retry==1.6.3"
            pip install "pytest-asyncio==0.21.1"
            pip install "google-cloud-aiplatform==1.43.0"
            pip install aiohttp
            pip install "openai==1.40.0"
            python -m pip install --upgrade pip
            pip install "pydantic==2.7.1"
            pip install "pytest==7.3.1"
            pip install "pytest-mock==3.12.0"
            pip install "pytest-asyncio==0.21.1"
            pip install "boto3==1.34.34"
            pip install mypy
            pip install pyarrow
            pip install numpydoc
            pip install prisma
            pip install fastapi
            pip install jsonschema
            pip install "httpx==0.24.1"
            pip install "anyio==3.7.1"
            pip install "asyncio==3.4.3"
            pip install "PyGithub==1.59.1"
      - run:
          name: Build Docker image
          command: docker build -t my-app:latest -f Dockerfile.database .
      - run:
          name: Run Docker container
          command: |
            docker run -d \
              -p 4000:4000 \
              -e DATABASE_URL=$PROXY_DATABASE_URL \
              -e LITELLM_MASTER_KEY="sk-1234" \
              -e OPENAI_API_KEY=$OPENAI_API_KEY \
              -e LITELLM_LICENSE=$LITELLM_LICENSE \
              --name my-app \
              -v $(pwd)/litellm/proxy/example_config_yaml/pass_through_config.yaml:/app/config.yaml \
              -v $(pwd)/litellm/proxy/example_config_yaml/custom_auth_basic.py:/app/custom_auth_basic.py \
              my-app:latest \
              --config /app/config.yaml \
              --port 4000 \
              --detailed_debug \
      - run:
          name: Install curl and dockerize
          command: |
            sudo apt-get update
            sudo apt-get install -y curl
            sudo wget https://github.com/jwilder/dockerize/releases/download/v0.6.1/dockerize-linux-amd64-v0.6.1.tar.gz
            sudo tar -C /usr/local/bin -xzvf dockerize-linux-amd64-v0.6.1.tar.gz
            sudo rm dockerize-linux-amd64-v0.6.1.tar.gz
      - run:
          name: Start outputting logs
          command: docker logs -f my-app
          background: true
      - run:
          name: Wait for app to be ready
          command: dockerize -wait http://localhost:4000 -timeout 5m
      - run:
          name: Run tests
          command: |
            pwd
            ls
            python -m pytest -vv tests/pass_through_tests/ -x --junitxml=test-results/junit.xml --durations=5
          no_output_timeout: 120m
      # Store test results
      - store_test_results:
          path: test-results
  publish_to_pypi:
    docker:
      - image: cimg/python:3.8
@@ -457,6 +551,12 @@ workflows:
            only:
              - main
              - /litellm_.*/
      - proxy_pass_through_endpoint_tests:
          filters:
            branches:
              only:
                - main
                - /litellm_.*/
      - installing_litellm_on_python:
          filters:
            branches:
@@ -468,6 +568,7 @@ workflows:
            - local_testing
            - build_and_test
            - proxy_log_to_otel_tests
            - proxy_pass_through_endpoint_tests
          filters:
            branches:
              only:

View file

@@ -1,7 +1,9 @@
-from litellm.proxy._types import UserAPIKeyAuth, GenerateKeyRequest
-from fastapi import Request
import os
+
+from fastapi import Request
+
+from litellm.proxy._types import GenerateKeyRequest, UserAPIKeyAuth


async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    try:

View file

@@ -0,0 +1,14 @@
from fastapi import Request

from litellm.proxy._types import UserAPIKeyAuth


async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    try:
        return UserAPIKeyAuth(
            api_key="best-api-key-ever",
            user_id="best-user-id-ever",
            team_id="best-team-id-ever",
        )
    except:
        raise Exception

View file

@@ -0,0 +1,9 @@
model_list:
  - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/fake
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
general_settings:
  master_key: sk-1234
  custom_auth: custom_auth_basic.user_api_key_auth
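Taken together, the two new files above wire the test proxy so that every request authenticates as the same principal: custom_auth_basic.user_api_key_auth accepts any bearer token and maps it to api_key "best-api-key-ever", which is what the spend assertions later key on. A minimal sketch of a client call under this setup (assumes the proxy container from the CI job above is listening on localhost:4000; the token value is arbitrary):

import httpx

# Any Authorization header passes custom_auth_basic; spend for this call
# is attributed to "best-api-key-ever".
response = httpx.post(
    "http://localhost:4000/chat/completions",
    headers={"Authorization": "Bearer any-token"},
    json={
        "model": "fake-openai-endpoint",
        "messages": [{"role": "user", "content": "hello"}],
    },
    timeout=30,
)
print(response.status_code, response.json())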

View file

@@ -3,6 +3,7 @@ import asyncio
import json
import traceback
from base64 import b64encode
from datetime import datetime
from typing import AsyncIterable, List, Optional

import httpx
@@ -20,6 +21,7 @@ from fastapi.responses import StreamingResponse
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import (
    ConfigFieldInfo,
    ConfigFieldUpdate,
@@ -30,8 +32,12 @@ from litellm.proxy._types import (
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth

from .success_handler import PassThroughEndpointLogging

router = APIRouter()

pass_through_endpoint_logging = PassThroughEndpointLogging()


async def set_env_variables_in_header(custom_headers: dict):
    """
@@ -330,7 +336,7 @@ async def pass_through_request(
        async_client = httpx.AsyncClient(timeout=600)

        # create logging object
-       start_time = time.time()
+       start_time = datetime.now()
        logging_obj = Logging(
            model="unknown",
            messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
@@ -473,12 +479,15 @@ async def pass_through_request(
        content = await response.aread()

        ## LOG SUCCESS
-       end_time = time.time()
-       await logging_obj.async_success_handler(
+       end_time = datetime.now()
+       await pass_through_endpoint_logging.pass_through_async_success_handler(
+           httpx_response=response,
+           url_route=str(url),
            result="",
            start_time=start_time,
            end_time=end_time,
+           logging_obj=logging_obj,
            cache_hit=False,
        )
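Note the change from time.time() floats to datetime.now() objects for start_time/end_time: the new success handler (next file) forwards these straight into LiteLLM's logging, which works with datetimes. A rough sketch of the timing contract, assuming latency is derived by subtracting the two:

from datetime import datetime

start_time = datetime.now()
# ... the proxied httpx request runs here ...
end_time = datetime.now()

elapsed = end_time - start_time  # a timedelta
print(f"pass-through call took {elapsed.total_seconds():.3f}s")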

View file

@@ -0,0 +1,105 @@
import re
from datetime import datetime

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth


class PassThroughEndpointLogging:
    def __init__(self):
        self.TRACKED_VERTEX_ROUTES = [
            "generateContent",
            "streamGenerateContent",
            "predict",
        ]

    async def pass_through_async_success_handler(
        self,
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        if self.is_vertex_route(url_route):
            await self.vertex_passthrough_handler(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
        else:
            await logging_obj.async_success_handler(
                result="",
                start_time=start_time,
                end_time=end_time,
                cache_hit=False,
            )

    def is_vertex_route(self, url_route: str):
        for route in self.TRACKED_VERTEX_ROUTES:
            if route in url_route:
                return True
        return False

    def extract_model_from_url(self, url: str) -> str:
        pattern = r"/models/([^:]+)"
        match = re.search(pattern, url)
        if match:
            return match.group(1)
        return "unknown"

    async def vertex_passthrough_handler(
        self,
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        if "generateContent" in url_route:
            model = self.extract_model_from_url(url_route)

            instance_of_vertex_llm = VertexLLM()
            litellm_model_response: litellm.ModelResponse = (
                instance_of_vertex_llm._process_response(
                    model=model,
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
                    response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
                    data={},
                    print_verbose=litellm.print_verbose,
                    encoding=None,
                )
            )
            logging_obj.model = litellm_model_response.model
            logging_obj.model_call_details["model"] = logging_obj.model

            await logging_obj.async_success_handler(
                result=litellm_model_response,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
            )
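extract_model_from_url relies on the shape of Vertex AI routes, where the model name sits between /models/ and the :generateContent verb. A quick standalone check of that regex (the URL is a made-up example of the expected shape):

import re

url = (
    "https://localhost:4000/vertex-ai/v1/projects/my-project/locations/"
    "us-central1/publishers/google/models/gemini-1.0-pro:generateContent"
)
match = re.search(r"/models/([^:]+)", url)
print(match.group(1) if match else "unknown")  # gemini-1.0-pro

Note that is_vertex_route also matches streamGenerateContent and predict, but vertex_passthrough_handler only builds a ModelResponse for generateContent in this commit.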

View file

@@ -21,3 +21,4 @@ router_settings:
general_settings:
  master_key: sk-1234
  custom_auth: example_config_yaml.custom_auth_basic.user_api_key_auth

View file

@@ -696,10 +696,10 @@ def load_from_azure_key_vault(use_azure_key_vault: bool = False):
def cost_tracking():
    global prisma_client, custom_db_client
    if prisma_client is not None or custom_db_client is not None:
-       if isinstance(litellm.success_callback, list):
+       if isinstance(litellm._async_success_callback, list):
            verbose_proxy_logger.debug("setting litellm success callback to track cost")
-           if (_PROXY_track_cost_callback) not in litellm.success_callback:  # type: ignore
-               litellm.success_callback.append(_PROXY_track_cost_callback)  # type: ignore
+           if (_PROXY_track_cost_callback) not in litellm._async_success_callback:  # type: ignore
+               litellm._async_success_callback.append(_PROXY_track_cost_callback)  # type: ignore


async def _PROXY_failure_handler(
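Moving the cost callback from litellm.success_callback to litellm._async_success_callback lines up with the pass-through success path above, which goes through async_success_handler and therefore fires the async callback list. A sketch of the registration pattern with a hypothetical stand-in callback (the real _PROXY_track_cost_callback lives in proxy_server.py and writes spend to the database):

import litellm

async def stand_in_track_cost_callback(kwargs, completion_response, start_time, end_time):
    # hypothetical stand-in: the real callback records a SpendLog row
    print("cost event for model:", kwargs.get("model"))

# guarded append, mirroring cost_tracking() above
if stand_in_track_cost_callback not in litellm._async_success_callback:
    litellm._async_success_callback.append(stand_in_track_cost_callback)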

View file

@@ -976,6 +976,7 @@ async def test_async_embedding_bedrock():
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_async_completion_azure_caching():
    litellm.set_verbose = True
    customHandler_caching = CompletionCustomHandler()

View file

@ -0,0 +1,119 @@
"""
Test Vertex AI Pass Through
1. use Credentials client side, Assert SpendLog was created
"""
import vertexai
from vertexai.preview.generative_models import GenerativeModel
import tempfile
import json
import os
import pytest
import asyncio
# Path to your service account JSON file
SERVICE_ACCOUNT_FILE = "path/to/your/service-account.json"
def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
# If the file is empty or not valid JSON, create an empty dictionary
if not content or not content.strip():
service_account_key_data = {}
else:
# Attempt to load the existing JSON content
file.seek(0)
service_account_key_data = json.load(file)
except FileNotFoundError:
# If the file doesn't exist, create an empty dictionary
service_account_key_data = {}
# Update the service_account_key_data with environment variables
private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
private_key = private_key.replace("\\n", "\n")
service_account_key_data["private_key_id"] = private_key_id
service_account_key_data["private_key"] = private_key
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary files
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
async def call_spend_logs_endpoint():
"""
Call this
curl -X GET "http://0.0.0.0:4000/spend/logs" -H "Authorization: Bearer sk-1234"
"""
import datetime
import requests
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
url = f"http://0.0.0.0:4000/global/spend/logs?api_key=best-api-key-ever"
headers = {"Authorization": f"Bearer sk-1234"}
response = requests.get(url, headers=headers)
print("response from call_spend_logs_endpoint", response)
json_response = response.json()
# get spend for today
"""
json response looks like this
[{'date': '2024-08-30', 'spend': 0.00016600000000000002, 'api_key': 'best-api-key-ever'}]
"""
todays_date = datetime.datetime.now().strftime("%Y-%m-%d")
for spend_log in json_response:
if spend_log["date"] == todays_date:
return spend_log["spend"]
LITE_LLM_ENDPOINT = "http://localhost:4000"
@pytest.mark.asyncio()
async def test_basic_vertex_ai_pass_through_with_spendlog():
spend_before = await call_spend_logs_endpoint() or 0.0
load_vertex_ai_credentials()
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
api_transport="rest",
)
model = GenerativeModel(model_name="gemini-1.0-pro")
response = model.generate_content("hi")
print("response", response)
await asyncio.sleep(20)
spend_after = await call_spend_logs_endpoint()
print("spend_after", spend_after)
assert (
spend_after > spend_before
), "Spend should be greater than before. spend_before: {}, spend_after: {}".format(
spend_before, spend_after
)
pass
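The test sleeps a flat 20 seconds to give the proxy's spend-log writer time to flush before re-querying. If that proves flaky, one alternative is to poll until spend rises past the baseline; a sketch only, wait_for_spend_increase is not part of this commit and reuses call_spend_logs_endpoint from the file above:

import asyncio

async def wait_for_spend_increase(baseline: float, timeout_s: float = 60.0) -> float:
    # Poll the spend endpoint until spend exceeds baseline or timeout_s elapses.
    loop = asyncio.get_event_loop()
    deadline = loop.time() + timeout_s
    while loop.time() < deadline:
        spend = await call_spend_logs_endpoint() or 0.0
        if spend > baseline:
            return spend
        await asyncio.sleep(2)
    return await call_spend_logs_endpoint() or 0.0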

View file

@@ -0,0 +1,13 @@
{
  "type": "service_account",
  "project_id": "adroit-crow-413218",
  "private_key_id": "",
  "private_key": "",
  "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
  "client_id": "104886546564708740969",
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
  "token_uri": "https://oauth2.googleapis.com/token",
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
  "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
  "universe_domain": "googleapis.com"
}
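(The blank private_key_id and private_key fields are intentional: load_vertex_ai_credentials in the test above fills them in at runtime from the VERTEX_AI_PRIVATE_KEY_ID and VERTEX_AI_PRIVATE_KEY environment variables, so no secret is committed.)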