updates to proxy

This commit is contained in:
Krrish Dholakia 2023-09-28 17:58:39 -07:00
parent 13ff65a8fe
commit 4665b2a898
4 changed files with 38 additions and 10 deletions

View file

@@ -9,9 +9,10 @@ load_dotenv()
@click.option('--debug', is_flag=True, help='To debug the input') @click.option('--debug', is_flag=True, help='To debug the input')
@click.option('--temperature', default=None, type=float, help='Set temperature for the model') @click.option('--temperature', default=None, type=float, help='Set temperature for the model')
@click.option('--max_tokens', default=None, help='Set max tokens for the model') @click.option('--max_tokens', default=None, help='Set max tokens for the model')
def run_server(port, api_base, model, debug, temperature, max_tokens): @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
def run_server(port, api_base, model, debug, temperature, max_tokens, telemetry):
from .proxy_server import app, initialize from .proxy_server import app, initialize
initialize(model, api_base, debug, temperature, max_tokens) initialize(model, api_base, debug, temperature, max_tokens, telemetry)
try: try:
import uvicorn import uvicorn
except: except:

View file

@@ -3,8 +3,20 @@ sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
try:
import uvicorn
import fastapi
except ImportError:
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "uvicorn", "fastapi"])
print()
print("\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m")
print()
print()
import litellm import litellm
print(litellm.__file__)
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import json import json
@@ -16,23 +28,31 @@ user_model = None
user_debug = False user_debug = False
user_max_tokens = None user_max_tokens = None
user_temperature = None user_temperature = None
user_telemetry = False
#### HELPER FUNCTIONS ####
def print_verbose(print_statement): def print_verbose(print_statement):
global user_debug global user_debug
print(f"user_debug: {user_debug}") print(f"user_debug: {user_debug}")
if user_debug: if user_debug:
print(print_statement) print(print_statement)
def initialize(model, api_base, debug, temperature, max_tokens): def usage_telemetry(): # helps us know if people are using this feature. Set `litellm --telemetry False` to your cli call to turn this off
global user_model, user_api_base, user_debug, user_max_tokens, user_temperature if user_telemetry:
data = {
"feature": "local_proxy_server"
}
litellm.utils.litellm_telemetry(data=data)
def initialize(model, api_base, debug, temperature, max_tokens, telemetry):
global user_model, user_api_base, user_debug, user_max_tokens, user_temperature, user_telemetry
user_model = model user_model = model
user_api_base = api_base user_api_base = api_base
user_debug = debug user_debug = debug
user_max_tokens = max_tokens user_max_tokens = max_tokens
user_temperature = temperature user_temperature = temperature
user_telemetry = telemetry
# if debug: usage_telemetry()
# litellm.set_verbose = True
# for streaming # for streaming
def data_generator(response): def data_generator(response):
@@ -42,6 +62,7 @@ def data_generator(response):
print_verbose(f"returned chunk: {chunk}") print_verbose(f"returned chunk: {chunk}")
yield f"data: {json.dumps(chunk)}\n\n" yield f"data: {json.dumps(chunk)}\n\n"
#### API ENDPOINTS ####
@app.get("/models") # if project requires model list @app.get("/models") # if project requires model list
def model_list(): def model_list():
return dict( return dict(

View file

@@ -2776,10 +2776,16 @@ def litellm_telemetry(data):
uuid_value = str(uuid.uuid4()) uuid_value = str(uuid.uuid4())
try: try:
# Prepare the data to send to litellm logging api # Prepare the data to send to litellm logging api
try:
pkg_version = importlib.metadata.version("litellm")
except:
pkg_version = None
if "model" not in data:
data["model"] = None
payload = { payload = {
"uuid": uuid_value, "uuid": uuid_value,
"data": data, "data": data,
"version:": importlib.metadata.version("litellm"), "version:": pkg_version
} }
# Make the POST request to litellm logging api # Make the POST request to litellm logging api
response = requests.post( response = requests.post(