feat(proxy_server.py): return litellm version in response headers

Krrish Dholakia 2024-05-08 16:00:08 -07:00
parent 80378966a0
commit 6575143460
50 changed files with 260 additions and 140 deletions

View file

@@ -16,11 +16,11 @@ repos:
         name: Check if files match
         entry: python3 ci_cd/check_files_match.py
         language: system
-  - repo: local
-    hooks:
-      - id: mypy
-        name: mypy
-        entry: python3 -m mypy --ignore-missing-imports
-        language: system
-        types: [python]
-        files: ^litellm/
+  # - repo: local
+  #   hooks:
+  #     - id: mypy
+  #       name: mypy
+  #       entry: python3 -m mypy --ignore-missing-imports
+  #       language: system
+  #       types: [python]
+  #       files: ^litellm/

View file

@@ -291,7 +291,7 @@ def _create_clickhouse_aggregate_tables(client=None, table_names=[]):
 def _forecast_daily_cost(data: list):
-    import requests
+    import requests  # type: ignore
     from datetime import datetime, timedelta

     if len(data) == 0:

View file

@@ -10,8 +10,8 @@
 # s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
 import os
 import inspect
-import redis, litellm
-import redis.asyncio as async_redis
+import redis, litellm  # type: ignore
+import redis.asyncio as async_redis  # type: ignore
 from typing import List, Optional

View file

@@ -10,7 +10,7 @@
 import os, json, time
 import litellm
 from litellm.utils import ModelResponse
-import requests, threading
+import requests, threading  # type: ignore
 from typing import Optional, Union, Literal

View file

@@ -1,7 +1,6 @@
 #### What this does ####
 # On success + failure, log events to aispend.io
 import dotenv, os
-import requests

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -4,18 +4,30 @@ import datetime
 class AthinaLogger:
     def __init__(self):
         import os

         self.athina_api_key = os.getenv("ATHINA_API_KEY")
         self.headers = {
             "athina-api-key": self.athina_api_key,
-            "Content-Type": "application/json"
+            "Content-Type": "application/json",
         }
         self.athina_logging_url = "https://log.athina.ai/api/v1/log/inference"
-        self.additional_keys = ["environment", "prompt_slug", "customer_id", "customer_user_id", "session_id", "external_reference_id", "context", "expected_response", "user_query"]
+        self.additional_keys = [
+            "environment",
+            "prompt_slug",
+            "customer_id",
+            "customer_user_id",
+            "session_id",
+            "external_reference_id",
+            "context",
+            "expected_response",
+            "user_query",
+        ]

     def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
-        import requests
+        import requests  # type: ignore
         import json
         import traceback

         try:
             response_json = response_obj.model_dump() if response_obj else {}
             data = {
@@ -23,32 +35,51 @@ class AthinaLogger:
                 "request": kwargs,
                 "response": response_json,
                 "prompt_tokens": response_json.get("usage", {}).get("prompt_tokens"),
-                "completion_tokens": response_json.get("usage", {}).get("completion_tokens"),
+                "completion_tokens": response_json.get("usage", {}).get(
+                    "completion_tokens"
+                ),
                 "total_tokens": response_json.get("usage", {}).get("total_tokens"),
             }

-            if type(end_time) == datetime.datetime and type(start_time) == datetime.datetime:
-                data["response_time"] = int((end_time - start_time).total_seconds() * 1000)
+            if (
+                type(end_time) == datetime.datetime
+                and type(start_time) == datetime.datetime
+            ):
+                data["response_time"] = int(
+                    (end_time - start_time).total_seconds() * 1000
+                )

             if "messages" in kwargs:
                 data["prompt"] = kwargs.get("messages", None)

             # Directly add tools or functions if present
             optional_params = kwargs.get("optional_params", {})
-            data.update((k, v) for k, v in optional_params.items() if k in ["tools", "functions"])
+            data.update(
+                (k, v)
+                for k, v in optional_params.items()
+                if k in ["tools", "functions"]
+            )

             # Add additional metadata keys
             metadata = kwargs.get("litellm_params", {}).get("metadata", {})
             if metadata:
                 for key in self.additional_keys:
                     if key in metadata:
                         data[key] = metadata[key]

-            response = requests.post(self.athina_logging_url, headers=self.headers, data=json.dumps(data, default=str))
+            response = requests.post(
+                self.athina_logging_url,
+                headers=self.headers,
+                data=json.dumps(data, default=str),
+            )
             if response.status_code != 200:
-                print_verbose(f"Athina Logger Error - {response.text}, {response.status_code}")
+                print_verbose(
+                    f"Athina Logger Error - {response.text}, {response.status_code}"
+                )
             else:
                 print_verbose(f"Athina Logger Succeeded - {response.text}")
         except Exception as e:
-            print_verbose(f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}")
+            print_verbose(
+                f"Athina Logger Error - {e}, Stack trace: {traceback.format_exc()}"
+            )
             pass

View file

@@ -1,7 +1,7 @@
 #### What this does ####
 # On success + failure, log events to aispend.io
 import dotenv, os
-import requests
+import requests  # type: ignore

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -3,7 +3,6 @@
 #### What this does ####
 # On success, logs events to Promptlayer
 import dotenv, os
-import requests
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

View file

@@ -1,7 +1,6 @@
 #### What this does ####
 # On success, logs events to Promptlayer
 import dotenv, os
-import requests
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

View file

@@ -2,7 +2,7 @@
 # On success + failure, log events to Supabase
 import dotenv, os
-import requests
+import requests  # type: ignore

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -2,7 +2,7 @@
 # On success + failure, log events to Supabase
 import dotenv, os
-import requests
+import requests  # type: ignore

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -1,15 +1,17 @@
-import requests
+import requests  # type: ignore
 import json
 import traceback
 from datetime import datetime, timezone


 class GreenscaleLogger:
     def __init__(self):
         import os

         self.greenscale_api_key = os.getenv("GREENSCALE_API_KEY")
         self.headers = {
             "api-key": self.greenscale_api_key,
-            "Content-Type": "application/json"
+            "Content-Type": "application/json",
         }
         self.greenscale_logging_url = os.getenv("GREENSCALE_ENDPOINT")
@@ -19,33 +21,48 @@ class GreenscaleLogger:
             data = {
                 "modelId": kwargs.get("model"),
                 "inputTokenCount": response_json.get("usage", {}).get("prompt_tokens"),
-                "outputTokenCount": response_json.get("usage", {}).get("completion_tokens"),
+                "outputTokenCount": response_json.get("usage", {}).get(
+                    "completion_tokens"
+                ),
             }
-            data["timestamp"] = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
-            if type(end_time) == datetime and type(start_time) == datetime:
-                data["invocationLatency"] = int((end_time - start_time).total_seconds() * 1000)
+            data["timestamp"] = datetime.now(timezone.utc).strftime(
+                "%Y-%m-%dT%H:%M:%SZ"
+            )
+
+            if type(end_time) == datetime and type(start_time) == datetime:
+                data["invocationLatency"] = int(
+                    (end_time - start_time).total_seconds() * 1000
+                )

             # Add additional metadata keys to tags
             tags = []
             metadata = kwargs.get("litellm_params", {}).get("metadata", {})
             for key, value in metadata.items():
                 if key.startswith("greenscale"):
                     if key == "greenscale_project":
                         data["project"] = value
                     elif key == "greenscale_application":
                         data["application"] = value
                     else:
-                        tags.append({"key": key.replace("greenscale_", ""), "value": str(value)})
+                        tags.append(
+                            {"key": key.replace("greenscale_", ""), "value": str(value)}
+                        )
             data["tags"] = tags

-            response = requests.post(self.greenscale_logging_url, headers=self.headers, data=json.dumps(data, default=str))
+            response = requests.post(
+                self.greenscale_logging_url,
+                headers=self.headers,
+                data=json.dumps(data, default=str),
+            )
             if response.status_code != 200:
-                print_verbose(f"Greenscale Logger Error - {response.text}, {response.status_code}")
+                print_verbose(
+                    f"Greenscale Logger Error - {response.text}, {response.status_code}"
+                )
             else:
                 print_verbose(f"Greenscale Logger Succeeded - {response.text}")
         except Exception as e:
-            print_verbose(f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}")
+            print_verbose(
+                f"Greenscale Logger Error - {e}, Stack trace: {traceback.format_exc()}"
+            )
             pass

View file

@@ -1,7 +1,7 @@
 #### What this does ####
 # On success, logs events to Helicone
 import dotenv, os
-import requests
+import requests  # type: ignore
 import litellm

 dotenv.load_dotenv()  # Loading env variables using dotenv

View file

@@ -1,15 +1,14 @@
 #### What this does ####
 # On success, logs events to Langsmith
-import dotenv, os
-import requests
-import requests
+import dotenv, os  # type: ignore
+import requests  # type: ignore
 from datetime import datetime

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
 import asyncio
 import types
-from pydantic import BaseModel
+from pydantic import BaseModel  # type: ignore


 def is_serializable(value):
@@ -79,8 +78,6 @@ class LangsmithLogger:
         except:
             response_obj = response_obj.dict()  # type: ignore
-        print(f"response_obj: {response_obj}")
-
         data = {
             "name": run_name,
             "run_type": "llm",  # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
@@ -90,7 +87,6 @@ class LangsmithLogger:
             "start_time": start_time,
             "end_time": end_time,
         }
-        print(f"data: {data}")

         response = requests.post(
             "https://api.smith.langchain.com/runs",

View file

@@ -2,7 +2,6 @@
 ## On Success events log cost to OpenMeter - https://github.com/BerriAI/litellm/issues/1268
 import dotenv, os, json
-import requests
 import litellm

 dotenv.load_dotenv()  # Loading env variables using dotenv
@@ -60,7 +59,7 @@ class OpenMeterLogger(CustomLogger):
             "total_tokens": response_obj["usage"].get("total_tokens"),
         }

-        subject = kwargs.get("user", None),  # end-user passed in via 'user' param
+        subject = (kwargs.get("user", None),)  # end-user passed in via 'user' param
         if not subject:
             raise Exception("OpenMeter: user is required")
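
Note on the hunk above: Black only re-wraps the line, and in doing so makes the pre-existing trailing comma visible. `subject` is therefore a one-element tuple rather than the raw user value, so the `if not subject:` guard cannot fire even when no user was passed. A minimal sketch of the gotcha, using plain dicts rather than anything litellm-specific:

import json  # noqa: F401  (unused here; mirrors the surrounding module style)

kwargs = {}  # no "user" key supplied
subject = (kwargs.get("user", None),)  # trailing comma -> 1-tuple: (None,)

print(bool(subject))     # True: a non-empty tuple is always truthy
print(bool(subject[0]))  # False: the actual user value is missing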

View file

@@ -3,7 +3,7 @@
 # On success, log events to Prometheus
 import dotenv, os
-import requests
+import requests  # type: ignore

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -19,7 +19,6 @@ class PrometheusLogger:
         **kwargs,
     ):
         try:
-            print(f"in init prometheus metrics")
             from prometheus_client import Counter

             self.litellm_llm_api_failed_requests_metric = Counter(

View file

@@ -4,7 +4,7 @@
 import dotenv, os
-import requests
+import requests  # type: ignore

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback
@@ -183,7 +183,6 @@ class PrometheusServicesLogger:
         )

     async def async_service_failure_hook(self, payload: ServiceLoggerPayload):
-        print(f"received error payload: {payload.error}")
         if self.mock_testing:
             self.mock_testing_failure_calls += 1

View file

@@ -1,12 +1,13 @@
 #### What this does ####
 # On success, logs events to Promptlayer
 import dotenv, os
-import requests
+import requests  # type: ignore
 from pydantic import BaseModel

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback


 class PromptLayerLogger:
     # Class variables or attributes
     def __init__(self):
@@ -32,7 +33,11 @@ class PromptLayerLogger:
                 tags = kwargs["litellm_params"]["metadata"]["pl_tags"]

                 # Remove "pl_tags" from metadata
-                metadata = {k:v for k, v in kwargs["litellm_params"]["metadata"].items() if k != "pl_tags"}
+                metadata = {
+                    k: v
+                    for k, v in kwargs["litellm_params"]["metadata"].items()
+                    if k != "pl_tags"
+                }

         print_verbose(
             f"Prompt Layer Logging - Enters logging function for model kwargs: {new_kwargs}\n, response: {response_obj}"

View file

@@ -2,7 +2,6 @@
 # On success + failure, log events to Supabase
 import dotenv, os
-import requests

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -2,7 +2,7 @@
 # On success + failure, log events to Supabase
 import dotenv, os
-import requests
+import requests  # type: ignore

 dotenv.load_dotenv()  # Loading env variables using dotenv
 import traceback

View file

@@ -1,8 +1,8 @@
 import os, types, traceback
 import json
 from enum import Enum
-import requests
-import time, httpx
+import requests  # type: ignore
+import time, httpx  # type: ignore
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Choices, Message
 import litellm

View file

@@ -1,12 +1,12 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 import litellm
 from litellm.utils import ModelResponse, Choices, Message, Usage
-import httpx
+import httpx  # type: ignore


 class AlephAlphaError(Exception):

View file

@@ -1,7 +1,7 @@
 import os, types
 import json
 from enum import Enum
-import requests, copy
+import requests, copy  # type: ignore
 import time
 from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
@@ -9,7 +9,7 @@ import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 from .base import BaseLLM
-import httpx
+import httpx  # type: ignore


 class AnthropicConstants(Enum):

View file

@@ -1,5 +1,5 @@
 from typing import Optional, Union, Any
-import types, requests
+import types, requests  # type: ignore
 from .base import BaseLLM
 from litellm.utils import (
     ModelResponse,
@@ -12,7 +12,7 @@ from litellm.utils import (
 from typing import Callable, Optional, BinaryIO
 from litellm import OpenAIConfig
 import litellm, json
-import httpx
+import httpx  # type: ignore
 from .custom_httpx.azure_dall_e_2 import CustomHTTPTransport, AsyncCustomHTTPTransport
 from openai import AzureOpenAI, AsyncAzureOpenAI
 import uuid

View file

@@ -1,5 +1,5 @@
 from typing import Optional, Union, Any
-import types, requests
+import types, requests  # type: ignore
 from .base import BaseLLM
 from litellm.utils import (
     ModelResponse,

View file

@@ -1,7 +1,7 @@
 import os
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable
 from litellm.utils import ModelResponse, Usage

View file

@@ -163,10 +163,9 @@ class AmazonAnthropicClaude3Config:
             "stop",
             "temperature",
             "top_p",
-            "extra_headers"
+            "extra_headers",
         ]

     def map_openai_params(self, non_default_params: dict, optional_params: dict):
         for param, value in non_default_params.items():
             if param == "max_tokens":
@@ -534,10 +533,12 @@ class AmazonStabilityConfig:
 def add_custom_header(headers):
     """Closure to capture the headers and add them."""

     def callback(request, **kwargs):
         """Actual callback function that Boto3 will call."""
         for header_name, header_value in headers.items():
             request.headers.add_header(header_name, header_value)

     return callback
@@ -672,7 +673,9 @@ def init_bedrock_client(
         config=config,
     )
     if extra_headers:
-        client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
+        client.meta.events.register(
+            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
+        )

     return client
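
For context on the hunk above: `init_bedrock_client` hooks `add_custom_header` into the client's `before-sign.bedrock-runtime.*` event, so the extra headers are added to every Bedrock runtime request before SigV4 signing. A minimal standalone sketch of the same boto3 event mechanism, reusing the closure from this file; the client region and the header value are illustrative, not from the diff:

import boto3

def add_custom_header(headers):
    """Closure to capture the headers and add them."""

    def callback(request, **kwargs):
        """Actual callback function that Boto3 will call."""
        for header_name, header_value in headers.items():
            request.headers.add_header(header_name, header_value)

    return callback

client = boto3.client("bedrock-runtime", region_name="us-east-1")  # illustrative region
client.meta.events.register(
    "before-sign.bedrock-runtime.*",
    add_custom_header({"x-example-header": "demo"}),  # hypothetical header
)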
@@ -1224,7 +1227,7 @@ def _embedding_func_single(
         "input_type", "search_document"
     )  # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
     data = {"texts": [input], **inference_params}  # type: ignore
-    body = json.dumps(data).encode("utf-8")
+    body = json.dumps(data).encode("utf-8")  # type: ignore
     ## LOGGING
     request_str = f"""
    response = client.invoke_model(
@@ -1416,7 +1419,7 @@ def image_generation(
     ## LOGGING
     request_str = f"""
    response = client.invoke_model(
-        body={body},
+        body={body},  # type: ignore
        modelId={modelId},
        accept="application/json",
        contentType="application/json",

View file

@@ -1,11 +1,11 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 import litellm
-import httpx
+import httpx  # type: ignore
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt

View file

@@ -1,12 +1,12 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time, traceback
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Choices, Message, Usage
 import litellm
-import httpx
+import httpx  # type: ignore


 class CohereError(Exception):

View file

@@ -1,12 +1,12 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time, traceback
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Choices, Message, Usage
 import litellm
-import httpx
+import httpx  # type: ignore
 from .prompt_templates.factory import cohere_message_pt

View file

@@ -1,7 +1,7 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time, traceback
 from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Choices, Message, Usage

View file

@@ -1,7 +1,7 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 import litellm

View file

@@ -1,10 +1,10 @@
 from itertools import chain
-import requests, types, time
+import requests, types, time  # type: ignore
 import json, uuid
 import traceback
 from typing import Optional
 import litellm
-import httpx, aiohttp, asyncio
+import httpx, aiohttp, asyncio  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -220,7 +220,10 @@ def get_ollama_response(
                 tool_calls=[
                     {
                         "id": f"call_{str(uuid.uuid4())}",
-                        "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+                        "function": {
+                            "name": function_call["name"],
+                            "arguments": json.dumps(function_call["arguments"]),
+                        },
                         "type": "function",
                     }
                 ],
@@ -232,7 +235,9 @@ def get_ollama_response(
     model_response["created"] = int(time.time())
     model_response["model"] = "ollama/" + model
     prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(prompt, disallowed_special=())))  # type: ignore
-    completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
+    completion_tokens = response_json.get(
+        "eval_count", len(response_json.get("message", dict()).get("content", ""))
+    )
     model_response["usage"] = litellm.Usage(
         prompt_tokens=prompt_tokens,
         completion_tokens=completion_tokens,
@@ -273,7 +278,10 @@ def ollama_completion_stream(url, data, logging_obj):
                 tool_calls=[
                     {
                         "id": f"call_{str(uuid.uuid4())}",
-                        "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+                        "function": {
+                            "name": function_call["name"],
+                            "arguments": json.dumps(function_call["arguments"]),
+                        },
                         "type": "function",
                     }
                 ],
@@ -314,9 +322,10 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
             first_chunk_content = first_chunk.choices[0].delta.content or ""
             response_content = first_chunk_content + "".join(
                 [
                     chunk.choices[0].delta.content
                     async for chunk in streamwrapper
-                    if chunk.choices[0].delta.content]
+                    if chunk.choices[0].delta.content
+                ]
             )
             function_call = json.loads(response_content)
             delta = litellm.utils.Delta(
@@ -324,7 +333,10 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
                 tool_calls=[
                     {
                         "id": f"call_{str(uuid.uuid4())}",
-                        "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+                        "function": {
+                            "name": function_call["name"],
+                            "arguments": json.dumps(function_call["arguments"]),
+                        },
                         "type": "function",
                     }
                 ],
@@ -373,7 +385,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
                 tool_calls=[
                     {
                         "id": f"call_{str(uuid.uuid4())}",
-                        "function": {"name": function_call["name"], "arguments": json.dumps(function_call["arguments"])},
+                        "function": {
+                            "name": function_call["name"],
+                            "arguments": json.dumps(function_call["arguments"]),
+                        },
                         "type": "function",
                     }
                 ],
@@ -387,7 +402,10 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
         model_response["created"] = int(time.time())
         model_response["model"] = "ollama/" + data["model"]
         prompt_tokens = response_json.get("prompt_eval_count", len(encoding.encode(data["prompt"], disallowed_special=())))  # type: ignore
-        completion_tokens = response_json.get("eval_count", len(response_json.get("message",dict()).get("content", "")))
+        completion_tokens = response_json.get(
+            "eval_count",
+            len(response_json.get("message", dict()).get("content", "")),
+        )
         model_response["usage"] = litellm.Usage(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -475,6 +493,7 @@ async def ollama_aembeddings(
     }
    return model_response


 def ollama_embeddings(
     api_base: str,
     model: str,
@@ -492,5 +511,6 @@ def ollama_embeddings(
             optional_params,
             logging_obj,
             model_response,
-            encoding)
+            encoding,
+        )
     )
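
The repeated re-wrap above touches the tool-call dictionaries that the Ollama handlers synthesize from the model's JSON output; the shape mirrors the OpenAI tool-call format, with `arguments` serialized back to a JSON string. A minimal sketch of that mapping; the `function_call` value here is a made-up stand-in for what the handlers parse out of the response:

import json
import uuid

# stand-in for the parsed JSON the model returned in function-calling mode
function_call = {"name": "get_weather", "arguments": {"city": "Paris"}}  # hypothetical

tool_call = {
    "id": f"call_{str(uuid.uuid4())}",
    "function": {
        "name": function_call["name"],
        # arguments must be a JSON *string*, not a dict, to match the OpenAI schema
        "arguments": json.dumps(function_call["arguments"]),
    },
    "type": "function",
}
print(tool_call["function"]["arguments"])  # '{"city": "Paris"}'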

View file

@@ -1,7 +1,7 @@
 import os
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage

View file

@@ -22,7 +22,6 @@ from litellm.utils import (
     TextCompletionResponse,
 )
 from typing import Callable, Optional
-import aiohttp, requests
 import litellm
 from .prompt_templates.factory import prompt_factory, custom_prompt
 from openai import OpenAI, AsyncOpenAI

View file

@@ -1,7 +1,7 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 import litellm

View file

@@ -1,11 +1,11 @@
 import os, types
 import json
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 from litellm.utils import ModelResponse, Usage
 import litellm
-import httpx
+import httpx  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt

View file

@@ -1,14 +1,14 @@
 import os, types, traceback
 from enum import Enum
 import json
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional, Any
 import litellm
 from litellm.utils import ModelResponse, EmbeddingResponse, get_secret, Usage
 import sys
 from copy import deepcopy
-import httpx
+import httpx  # type: ignore
 from .prompt_templates.factory import prompt_factory, custom_prompt
@@ -295,7 +295,7 @@ def completion(
             EndpointName={model},
             InferenceComponentName={model_id},
             ContentType="application/json",
-            Body={data},
+            Body={data},  # type: ignore
             CustomAttributes="accept_eula=true",
         )
         """  # type: ignore
@@ -321,7 +321,7 @@ def completion(
         response = client.invoke_endpoint(
             EndpointName={model},
             ContentType="application/json",
-            Body={data},
+            Body={data},  # type: ignore
             CustomAttributes="accept_eula=true",
         )
         """  # type: ignore
@@ -688,7 +688,7 @@ def embedding(
     response = client.invoke_endpoint(
         EndpointName={model},
         ContentType="application/json",
-        Body={data},
+        Body={data},  # type: ignore
        CustomAttributes="accept_eula=true",
    )"""  # type: ignore
     logging_obj.pre_call(

View file

@@ -6,11 +6,11 @@ Reference: https://docs.together.ai/docs/openai-api-compatibility
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional
 import litellm
-import httpx
+import httpx  # type: ignore
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt

View file

@@ -1,12 +1,12 @@
 import os, types
 import json
 from enum import Enum
-import requests
+import requests  # type: ignore
 import time
 from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
-import httpx, inspect
+import httpx, inspect  # type: ignore


 class VertexAIError(Exception):

View file

@@ -3,7 +3,7 @@
 import os, types
 import json
 from enum import Enum
-import requests, copy
+import requests, copy  # type: ignore
 import time, uuid
 from typing import Callable, Optional, List
 from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
@@ -17,7 +17,7 @@ from .prompt_templates.factory import (
     extract_between_tags,
     parse_xml_params,
 )
-import httpx
+import httpx  # type: ignore


 class VertexAIError(Exception):

View file

@@ -1,8 +1,8 @@
 import os
 import json
 from enum import Enum
-import requests
-import time, httpx
+import requests  # type: ignore
+import time, httpx  # type: ignore
 from typing import Callable, Any
 from litellm.utils import ModelResponse, Usage
 from .prompt_templates.factory import prompt_factory, custom_prompt

View file

@@ -3,8 +3,8 @@ import json, types, time  # noqa: E401
 from contextlib import contextmanager
 from typing import Callable, Dict, Optional, Any, Union, List

-import httpx
-import requests
+import httpx  # type: ignore
+import requests  # type: ignore
 import litellm
 from litellm.utils import ModelResponse, get_secret, Usage

View file

@@ -252,7 +252,7 @@ def run_server(
     if model and "ollama" in model and api_base is None:
         run_ollama_serve()
     if test_async is True:
-        import requests, concurrent, time
+        import requests, concurrent, time  # type: ignore

         api_base = f"http://{host}:{port}"
@@ -418,7 +418,7 @@ def run_server(
                 read from there and save it to os.env['DATABASE_URL']
                 """
                 try:
-                    import yaml, asyncio
+                    import yaml, asyncio  # type: ignore
                 except:
                     raise ImportError(
                         "yaml needs to be imported. Run - `pip install 'litellm[proxy]'`"

View file

@@ -30,7 +30,7 @@ sys.path.insert(
 try:
     import fastapi
     import backoff
-    import yaml
+    import yaml  # type: ignore
     import orjson
     import logging
     from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -3719,6 +3719,7 @@ async def chat_completion(
                 "x-litellm-model-id": model_id,
                 "x-litellm-cache-key": cache_key,
                 "x-litellm-model-api-base": api_base,
+                "x-litellm-version": version,
             }
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3734,6 +3735,7 @@ async def chat_completion(
         fastapi_response.headers["x-litellm-model-id"] = model_id
         fastapi_response.headers["x-litellm-cache-key"] = cache_key
         fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version

         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
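
With the two hunks above, both the streaming and non-streaming `/chat/completions` paths now return an `x-litellm-version` header alongside the existing model-id / cache-key / api-base headers. A minimal client-side sketch of reading it; the proxy URL, API key, and model name are placeholders, not values from this commit:

import requests

resp = requests.post(
    "http://localhost:4000/chat/completions",  # placeholder proxy URL
    headers={"Authorization": "Bearer sk-1234"},  # placeholder key
    json={
        "model": "gpt-3.5-turbo",  # placeholder model
        "messages": [{"role": "user", "content": "hi"}],
    },
)
print(resp.headers.get("x-litellm-version"))
print(resp.headers.get("x-litellm-model-id"))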
@@ -3890,14 +3892,10 @@ async def completion(
             },
         )

-        if hasattr(response, "_hidden_params"):
-            model_id = response._hidden_params.get("model_id", None) or ""
-            original_response = (
-                response._hidden_params.get("original_response", None) or ""
-            )
-        else:
-            model_id = ""
-            original_response = ""
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""

         verbose_proxy_logger.debug("final response: %s", response)
         if (
@@ -3905,6 +3903,9 @@ async def completion(
         ):  # use generate_responses to stream responses
             custom_headers = {
                 "x-litellm-model-id": model_id,
+                "x-litellm-cache-key": cache_key,
+                "x-litellm-model-api-base": api_base,
+                "x-litellm-version": version,
             }
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3919,6 +3920,10 @@ async def completion(
             )

         fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
@@ -3958,6 +3963,7 @@ async def completion(
 )  # azure compatible endpoint
 async def embeddings(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
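
Adding a `fastapi_response: Response` parameter here (and to the other endpoints below) is the standard FastAPI way to get a handle on the outgoing response so headers can be set even though the handler returns a plain object. A minimal sketch outside the proxy; the route path and header value are illustrative:

from fastapi import FastAPI, Response

app = FastAPI()

@app.get("/ping")  # illustrative route
async def ping(fastapi_response: Response):
    # headers set on the injected Response object are merged into the real reply
    fastapi_response.headers["x-litellm-version"] = "1.0.0"  # placeholder version string
    return {"ok": True}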
@@ -4104,6 +4110,17 @@ async def embeddings(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
@@ -4142,6 +4159,7 @@ async def embeddings(
 )
 async def image_generation(
     request: Request,
+    fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     global proxy_logging_obj
@@ -4261,6 +4279,17 @@ async def image_generation(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
@@ -4297,6 +4326,7 @@ async def image_generation(
 )
 async def audio_transcriptions(
     request: Request,
+    fastapi_response: Response,
     file: UploadFile = File(...),
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
@@ -4441,6 +4471,18 @@ async def audio_transcriptions(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
@@ -4480,6 +4522,7 @@ async def audio_transcriptions(
 )
 async def moderations(
     request: Request,
+    fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
@@ -4604,6 +4647,17 @@ async def moderations(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting

View file

@@ -1689,12 +1689,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
             module_file_path = os.path.join(directory, *module_name.split("."))
             module_file_path += ".py"

-            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
+            spec = importlib.util.spec_from_file_location(module_name, module_file_path)  # type: ignore
             if spec is None:
                 raise ImportError(
                     f"Could not find a module specification for {module_file_path}"
                 )
-            module = importlib.util.module_from_spec(spec)
+            module = importlib.util.module_from_spec(spec)  # type: ignore
             spec.loader.exec_module(module)  # type: ignore
         else:
             # Dynamically import the module
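
The three `importlib.util` calls annotated above are the standard recipe for loading a user-supplied module from a file path, which is what `get_instance_fn` does for custom callback classes. A minimal self-contained sketch; the module path and attribute name are placeholders:

import importlib.util

module_file_path = "custom_callbacks.py"  # placeholder path
module_name = "custom_callbacks"

spec = importlib.util.spec_from_file_location(module_name, module_file_path)
if spec is None:
    raise ImportError(f"Could not find a module specification for {module_file_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)  # type: ignore[union-attr]  # loader may be None per the stubs
instance = getattr(module, "my_custom_handler", None)  # hypothetical attribute name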

View file

@@ -6,7 +6,7 @@
 # - use litellm.success + failure callbacks to log when a request completed
 # - in get_available_deployment, for a given model group name -> pick based on traffic
-import dotenv, os, requests, random
+import dotenv, os, requests, random  # type: ignore
 from typing import Optional

 dotenv.load_dotenv()  # Loading env variables using dotenv

View file

@@ -1,7 +1,7 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
 from pydantic import BaseModel, Extra, Field, root_validator
-import dotenv, os, requests, random
+import dotenv, os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random

View file

@@ -1,7 +1,7 @@
 #### What this does ####
 # picks based on response time (for streaming, this is time to first token)
-from pydantic import BaseModel, Extra, Field, root_validator
-import dotenv, os, requests, random
+from pydantic import BaseModel, Extra, Field, root_validator  # type: ignore
+import dotenv, os, requests, random  # type: ignore
 from typing import Optional, Union, List, Dict
 from datetime import datetime, timedelta
 import random

View file

@@ -14,7 +14,7 @@ import subprocess, os
 from os.path import abspath, join, dirname
 import litellm, openai
 import itertools
-import random, uuid, requests
+import random, uuid, requests  # type: ignore
 from functools import wraps
 import datetime, time
 import tiktoken
@@ -36,7 +36,7 @@ import litellm._service_logger  # for storing API inputs, outputs, and metadata
 try:
     # this works in python 3.8
-    import pkg_resources
+    import pkg_resources  # type: ignore

     filename = pkg_resources.resource_filename(__name__, "llms/tokenizers")
     # try:
@@ -7732,11 +7732,11 @@ def _calculate_retry_after(
         try:
             retry_after = int(retry_header)
         except Exception:
-            retry_date_tuple = email.utils.parsedate_tz(retry_header)
+            retry_date_tuple = email.utils.parsedate_tz(retry_header)  # type: ignore
             if retry_date_tuple is None:
                 retry_after = -1
             else:
-                retry_date = email.utils.mktime_tz(retry_date_tuple)
+                retry_date = email.utils.mktime_tz(retry_date_tuple)  # type: ignore
                 retry_after = int(retry_date - time.time())
     else:
         retry_after = -1
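
The two `# type: ignore` comments above land on the HTTP-date fallback of the Retry-After parsing: `email.utils.parsedate_tz` turns an RFC 2822 date string into a time tuple and `mktime_tz` converts it to an epoch timestamp. A small standalone sketch with a made-up header value:

import email.utils
import time

retry_header = "Fri, 31 Dec 2100 23:59:59 GMT"  # example Retry-After value in HTTP-date form

try:
    retry_after = int(retry_header)  # numeric form, e.g. "120"
except Exception:
    retry_date_tuple = email.utils.parsedate_tz(retry_header)
    if retry_date_tuple is None:
        retry_after = -1
    else:
        retry_date = email.utils.mktime_tz(retry_date_tuple)
        retry_after = int(retry_date - time.time())

print(retry_after)  # seconds until the given date (negative if it is already past)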
@@ -9423,7 +9423,9 @@ def get_secret(
         else:
             secret = os.environ.get(secret_name)
         try:
-            secret_value_as_bool = ast.literal_eval(secret) if secret is not None else None
+            secret_value_as_bool = (
+                ast.literal_eval(secret) if secret is not None else None
+            )
             if isinstance(secret_value_as_bool, bool):
                 return secret_value_as_bool
             else:
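
The wrapped expression above tries to interpret a secret string as a Python literal, so env values like "True" or "False" come back as real booleans while anything else falls through to the raw string. A minimal sketch of that behavior, using an invented helper name:

import ast

def parse_bool_secret(secret):
    """Return a bool if the string literally spells one, else the raw string (hypothetical helper)."""
    try:
        value = ast.literal_eval(secret) if secret is not None else None
    except (ValueError, SyntaxError):
        return secret  # not a Python literal (e.g. "sk-1234"), keep as-is
    return value if isinstance(value, bool) else secret

print(parse_bool_secret("True"))     # True
print(parse_bool_secret("sk-1234"))  # 'sk-1234'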