test: fix config import for proxy testing

commit b7e75b940a (parent c040b478fc)
Krrish Dholakia, 2023-12-06 17:40:38 -08:00
3 changed files with 46 additions and 18 deletions

litellm/proxy/proxy_server.py

@@ -748,6 +748,19 @@ def litellm_completion(*args, **kwargs):
         return StreamingResponse(data_generator(response), media_type='text/event-stream')
     return response
 
+def get_litellm_model_info(model: dict = {}):
+    model_info = model.get("model_info", {})
+    model_to_lookup = model.get("litellm_params", {}).get("model", None)
+    try:
+        if "azure" in model_to_lookup:
+            model_to_lookup = model_info.get("base_model", None)
+        litellm_model_info = litellm.get_model_info(model_to_lookup)
+        return litellm_model_info
+    except:
+        # this should not block returning on /model/info
+        # if litellm does not have info on the model it should return {}
+        return {}
+
 @app.middleware("http")
 async def rate_limit_per_token(request: Request, call_next):
     global user_api_key_cache, general_settings
@@ -1101,19 +1114,6 @@ async def add_new_model(model_params: ModelParams):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
 
-def get_litellm_model_info(model: dict = {}):
-    model_info = model.get("model_info", {})
-    model_to_lookup = model.get("litellm_params", {}).get("model", None)
-    try:
-        if "azure" in model_to_lookup:
-            model_to_lookup = model_info.get("base_model", None)
-        litellm_model_info = litellm.get_model_info(model_to_lookup)
-        return litellm_model_info
-    except:
-        # this should not block returning on /model/info
-        # if litellm does not have info on the model it should return {}
-        return {}
-
 #### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
 @router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info_v1(request: Request):
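Taken together, these two hunks move get_litellm_model_info from below the /model/info endpoint to above its first caller. For reference, a minimal standalone sketch of what the helper does (the sample model dict here is hypothetical, shaped like the config entries in the new test config below; litellm.get_model_info is the real lookup against litellm's model cost map):

import litellm

# Hypothetical model entry, as the proxy's /model/info code would pass in:
# an Azure deployment whose underlying model is declared via model_info.base_model.
model = {
    "model_name": "azure-model",
    "litellm_params": {"model": "azure/my-deployment"},
    "model_info": {"base_model": "azure/gpt-35-turbo"},
}

model_to_lookup = model.get("litellm_params", {}).get("model", None)
if "azure" in model_to_lookup:
    # Azure deployment names are arbitrary, so the declared base model is looked up instead
    model_to_lookup = model.get("model_info", {}).get("base_model", None)

print(litellm.get_model_info(model_to_lookup))  # max_tokens, input/output cost per token, ...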

litellm/tests/test_configs/test_config_no_auth.yaml (new file)

@@ -0,0 +1,21 @@
+model_list:
+  - litellm_params:
+      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+      api_key: os.environ/AZURE_EUROPE_API_KEY
+      model: azure/gpt-35-turbo
+    model_name: azure-model
+  - litellm_params:
+      api_base: https://my-endpoint-canada-berri992.openai.azure.com
+      api_key: os.environ/AZURE_CANADA_API_KEY
+      model: azure/gpt-35-turbo
+    model_name: azure-model
+  - litellm_params:
+      api_base: https://openai-france-1234.openai.azure.com
+      api_key: os.environ/AZURE_FRANCE_API_KEY
+      model: azure/gpt-turbo
+    model_name: azure-model
+  - litellm_params:
+      model: gpt-3.5-turbo
+    model_info:
+      description: this is a test openai model
+    model_name: test_openai_models
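The three azure/* entries share the model_name "azure-model", so the proxy can spread azure-model traffic across all three deployments, and the os.environ/ values tell the proxy to read each key from the environment rather than storing it in the file. A small standalone sanity-check sketch (assumes PyYAML is installed and the file sits in test_configs/ relative to the working directory):

import yaml  # PyYAML, assumed available
from collections import defaultdict

with open("test_configs/test_config_no_auth.yaml") as f:
    config = yaml.safe_load(f)

# Group deployments by public model_name: "azure-model" should map to
# three deployments, "test_openai_models" to one.
groups = defaultdict(list)
for entry in config["model_list"]:
    groups[entry["model_name"]].append(entry["litellm_params"]["model"])

for name, deployments in groups.items():
    print(f"{name}: {deployments}")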

litellm/tests/test_proxy_server.py

@@ -10,21 +10,28 @@ import os, io
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
-import pytest
+import pytest, logging
 import litellm
 from litellm import embedding, completion, completion_cost, Timeout
 from litellm import RateLimitError
 
+# Configure logging
+logging.basicConfig(
+    level=logging.DEBUG,  # Set the desired logging level
+    format="%(asctime)s - %(levelname)s - %(message)s",
+)
+
 # test /chat/completion request to the proxy
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
-from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined
+from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined
-save_worker_config(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
+filepath = os.path.dirname(os.path.abspath(__file__))
+config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
+save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
 app = FastAPI()
 app.include_router(router)  # Include your router in the test app
 
 @app.on_event("startup")
-async def wrapper_startup_event():  # required to reset config on app init - b/c pytest collects across multiple files - which sets the fastapi client + WORKER CONFIG to whatever was collected last
-    initialize(config=None, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
+async def wrapper_startup_event():
+    await startup_event()
 
 # Here you create a fixture that will be used by your tests
 # Make sure the fixture returns TestClient(app)
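The closing comments describe a fixture the tests are expected to define; a minimal sketch of one that satisfies them (the fixture name and the sample request are assumptions, not part of this diff; app is the FastAPI instance built above). Entering TestClient as a context manager is what runs the startup events, so wrapper_startup_event loads the worker config from test_config_no_auth.yaml before each test:

import pytest
from fastapi.testclient import TestClient

@pytest.fixture
def client():
    # The "with" block fires the app's startup events, including
    # wrapper_startup_event above, which loads the worker config.
    with TestClient(app) as test_client:
        yield test_client

def test_chat_completion(client):
    # Hypothetical request against a model_name defined in the test config
    response = client.post("/chat/completions", json={
        "model": "azure-model",
        "messages": [{"role": "user", "content": "hi"}],
    })
    assert response.status_code == 200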