feat(proxy_server): add /v1/embeddings endpoint

n
This commit is contained in:
Krrish Dholakia 2023-11-22 14:03:20 -08:00
parent 40dd38508f
commit 448ec0a571

View file

@ -165,7 +165,6 @@ async def user_api_key_auth(request: Request):
try:
api_key = await oauth2_scheme(request=request)
route = request.url.path
if api_key == master_key:
return
@ -246,7 +245,6 @@ def run_ollama_serve():
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key
config = {}
general_settings: dict = {}
try:
if os.path.exists(config_file_path):
with open(config_file_path, 'r') as file:
@ -265,7 +263,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
os.environ[key] = value
## GENERAL SERVER SETTINGS (e.g. master key,..)
general_settings = config.get("general_settings", None)
general_settings = config.get("general_settings", {})
if general_settings is None:
general_settings = {}
if general_settings:
### MASTER KEY ###
master_key = general_settings.get("master_key", None)
@ -292,7 +292,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
# print(f"litellm_model_name: {litellm_model_name}")
if "ollama" in litellm_model_name:
run_ollama_serve()
print(f"returned general settings: {general_settings}")
return router, model_list, general_settings
async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
@ -584,26 +584,19 @@ async def chat_completion(request: Request, model: Optional[str] = None):
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
async def embeddings(request: Request, model: str):
async def embeddings(request: Request):
try:
body = await request.body()
body_str = body.decode()
try:
data = ast.literal_eval(body_str)
except:
data = json.loads(body_str)
data = await request.json()
print(f"data: {data}")
data["model"] = (
general_settings.get("embedding_model", None) # server default
or user_model # model name passed via cli args
or model # for azure http calls
or data["model"] # default passed in http request
)
if user_model:
data["model"] = user_model
## ROUTE TO CORRECT ENDPOINT ##
## ROUTE TO CORRECT ENDPOINT ##
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.aembedding(**data)
@ -611,6 +604,7 @@ async def embeddings(request: Request, model: str):
response = litellm.aembedding(**data)
return response
except Exception as e:
traceback.print_exc()
raise e
except Exception as e:
pass