Merge branch 'main' into litellm_call_id_in_response

This commit is contained in:
Krish Dholakia 2024-07-11 21:54:49 -07:00 committed by GitHub
commit 72f1c9181d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
119 changed files with 4737 additions and 1868 deletions

View file

@ -1,24 +1,18 @@
import ast
import asyncio
import copy
import hashlib
import importlib
import inspect
import os
import platform
import random
import re
import secrets
import shutil
import subprocess
import sys
import threading
import time
import traceback
import uuid
import warnings
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Set, get_args
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Optional
import requests
@ -106,7 +100,6 @@ import litellm
from litellm import (
CancelBatchRequest,
CreateBatchRequest,
CreateFileRequest,
ListBatchRequest,
RetrieveBatchRequest,
)
@ -174,6 +167,9 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
router as key_management_router,
)
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
from litellm.proxy.openai_files_endpoints.files_endpoints import (
router as openai_files_router,
)
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
initialize_pass_through_endpoints,
)
@ -213,6 +209,12 @@ from litellm.router import (
from litellm.router import ModelInfo as RouterModelInfo
from litellm.router import updateDeployment
from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
AnthropicResponseContentBlockText,
AnthropicResponseUsageBlock,
)
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.router import RouterGeneralSettings
@ -2667,6 +2669,11 @@ async def startup_event():
def model_list(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Use `/model/info` - to get detailed model information, example - pricing, mode, etc.
This is just for compatibility with openai projects like aider.
"""
global llm_model_list, general_settings
all_models = []
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
@ -2791,7 +2798,7 @@ async def chat_completion(
## LOGGING OBJECT ## - initialize logging object for logging success/failure events for call
## IMPORTANT Note: - initialize this before running pre-call checks. Ensures we log rejected requests to langfuse.
data["litellm_call_id"] = str(uuid.uuid4())
data["litellm_call_id"] = request.headers.get('x-litellm-call-id', str(uuid.uuid4()))
logging_obj, data = litellm.utils.function_setup(
original_function="acompletion",
rules_obj=litellm.utils.Rules(),
@ -3243,6 +3250,12 @@ async def completion(
response_class=ORJSONResponse,
tags=["embeddings"],
)
@router.post(
"/engines/{model:path}/embeddings",
dependencies=[Depends(user_api_key_auth)],
response_class=ORJSONResponse,
tags=["embeddings"],
) # azure compatible endpoint
@router.post(
"/openai/deployments/{model:path}/embeddings",
dependencies=[Depends(user_api_key_auth)],
@ -4891,117 +4904,6 @@ async def retrieve_batch(
######################################################################
######################################################################
# /v1/files Endpoints
######################################################################
@router.post(
"/v1/files",
dependencies=[Depends(user_api_key_auth)],
tags=["files"],
)
@router.post(
"/files",
dependencies=[Depends(user_api_key_auth)],
tags=["files"],
)
async def create_file(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Upload a file that can be used across - Assistants API, Batch API
This is the equivalent of POST https://api.openai.com/v1/files
Supports Identical Params as: https://platform.openai.com/docs/api-reference/files/create
Example Curl
```
curl https://api.openai.com/v1/files \
-H "Authorization: Bearer sk-1234" \
-F purpose="batch" \
-F file="@mydata.jsonl"
```
"""
global proxy_logging_obj
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
form_data = await request.form()
data = {key: value for key, value in form_data.items() if key != "file"}
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
_create_file_request = CreateFileRequest()
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
response = await litellm.acreate_file(
custom_llm_provider="openai", **_create_file_request
)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.create_file(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/moderations",
@ -5150,6 +5052,198 @@ async def moderations(
)
#### ANTHROPIC ENDPOINTS ####
@router.post(
"/v1/messages",
tags=["[beta] Anthropic `/v1/messages`"],
dependencies=[Depends(user_api_key_auth)],
response_model=AnthropicResponse,
)
async def anthropic_response(
anthropic_data: AnthropicMessagesRequest,
fastapi_response: Response,
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
from litellm import adapter_completion
from litellm.adapters.anthropic_adapter import anthropic_adapter
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
data: dict = {**anthropic_data, "adapter_id": "anthropic"}
try:
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
data = await add_litellm_data_to_request(
data=data, # type: ignore
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base
### MODEL ALIAS MAPPING ###
# check if model name in model alias map
# get the actual model name
if data["model"] in litellm.model_alias_map:
data["model"] = litellm.model_alias_map[data["model"]]
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook( # type: ignore
user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
)
### ROUTE THE REQUESTs ###
router_model_names = llm_router.model_names if llm_router is not None else []
# skip router if user passed their key
if "api_key" in data:
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in router_model_names
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and llm_router.model_group_alias is not None
and data["model"] in llm_router.model_group_alias
): # model set in model_group_alias
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None and data["model"] in llm_router.deployment_names
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(
llm_router.aadapter_completion(**data, specific_deployment=True)
)
elif (
llm_router is not None and data["model"] in llm_router.get_model_ids()
): # model in router model list
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
llm_router is not None
and data["model"] not in router_model_names
and llm_router.default_deployment is not None
): # model in router deployments, calling a specific deployment on the router
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif user_model is not None: # `litellm --model <your-model-name>`
llm_response = asyncio.create_task(litellm.aadapter_completion(**data))
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "completion: Invalid model name passed in model="
+ data.get("model", "")
},
)
# Await the llm_response task
response = await llm_response
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
verbose_proxy_logger.debug("final response: %s", response)
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
response_cost=response_cost,
)
)
verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response))
return response
except RejectedRequestError as e:
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
)
if _data.get("stream", None) is not None and _data["stream"] == True:
_chat_response = litellm.ModelResponse()
_usage = litellm.Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
_chat_response.usage = _usage # type: ignore
_chat_response.choices[0].message.content = e.message # type: ignore
_iterator = litellm.utils.ModelResponseIterator(
model_response=_chat_response, convert_to_delta=True
)
_streaming_response = litellm.TextCompletionStreamWrapper(
completion_stream=_iterator,
model=_data.get("model", ""),
)
selected_data_generator = select_data_generator(
response=_streaming_response,
user_api_key_dict=user_api_key_dict,
request_data=data,
)
return StreamingResponse(
selected_data_generator,
media_type="text/event-stream",
headers={},
)
else:
_response = litellm.TextCompletionResponse()
_response.choices[0].text = e.message
return _response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
#### DEV UTILS ####
# @router.get(
@ -9302,3 +9396,4 @@ app.include_router(caching_router)
app.include_router(analytics_router)
app.include_router(debugging_endpoints_router)
app.include_router(ui_crud_endpoints_router)
app.include_router(openai_files_router)