This commit is contained in:
coconut49 2023-10-18 01:47:56 +08:00
commit cfeaa79bea
8 changed files with 106 additions and 18 deletions

View file

@ -2,6 +2,7 @@ import sys, os, platform, time, copy
import threading
import shutil, random, traceback
messages = []
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path - for litellm local dev
@ -72,6 +73,7 @@ print()
import litellm
from fastapi import FastAPI, Request
from fastapi.routing import APIRouter
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
import json
@ -111,6 +113,13 @@ def print_verbose(print_statement):
print(print_statement)
def find_avatar_url(role):
role = role.replace(" ", "%20")
avatar_filename = f"avatars/{role}.png"
avatar_url = f"/static/{avatar_filename}"
return avatar_url
def usage_telemetry(
feature: str): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
if user_telemetry:
@ -461,24 +470,21 @@ def model_list():
)
@router.post("/v1/completions")
@router.post("/completions")
async def completion(request: Request):
data = await request.json()
print_verbose(f"data passed in: {data}")
return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
user_debug=user_debug)
@router.post("/v1/chat/completions")
@router.post("/chat/completions")
async def chat_completion(request: Request):
data = await request.json()
print_verbose(f"data passed in: {data}")
return litellm_completion(data, type="chat_completion", user_model=user_model,
user_temperature=user_temperature, user_max_tokens=user_max_tokens,
user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
user_temperature=user_temperature, user_max_tokens=user_max_tokens,
user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
def print_cost_logs():