updates to proxy

This commit is contained in:
Krrish Dholakia 2023-09-28 17:58:39 -07:00
parent 13ff65a8fe
commit 4665b2a898
4 changed files with 38 additions and 10 deletions

View file

@ -3,8 +3,20 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
try:
import uvicorn
import fastapi
except ImportError:
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi"])
print()
print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m")
print()
print()
import litellm
print(litellm.__file__)
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import json
@ -16,23 +28,31 @@ user_model = None
user_debug = False
user_max_tokens = None
user_temperature = None
user_telemetry = False
#### HELPER FUNCTIONS ####
def print_verbose(print_statement):
global user_debug
print(f"user_debug: {user_debug}")
if user_debug:
print(print_statement)
def initialize(model, api_base, debug, temperature, max_tokens):
global user_model, user_api_base, user_debug, user_max_tokens, user_temperature
def usage_telemetry(): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
if user_telemetry:
data = {
"feature": "local_proxy_server"
}
litellm.utils.litellm_telemetry(data=data)
def initialize(model, api_base, debug, temperature, max_tokens, telemetry):
global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry
user_model = model
user_api_base = api_base
user_debug = debug
user_max_tokens = max_tokens
user_temperature = temperature
# if debug:
# litellm.set_verbose = True
user_telemetry = telemetry
usage_telemetry()
# for streaming
def data_generator(response):
@ -41,7 +61,8 @@ def data_generator(response):
print(f"chunk: {chunk}")
print_verbose(f"returned chunk: {chunk}")
yield f"data: {json.dumps(chunk)}\n\n"
#### API ENDPOINTS ####
@app.get("/models") # if project requires model list
def model_list():
return dict(