From 448ec0a5714a0fa291ecaa1f55d298aa42dbc5a8 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 22 Nov 2023 14:03:20 -0800
Subject: [PATCH] feat(proxy_server): add /v1/embeddings endpoint

---
 litellm/proxy/proxy_server.py | 24 +++++++++---------------
 1 file changed, 9 insertions(+), 15 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 91238a8e4..a4dcc09d4 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -165,7 +165,6 @@ async def user_api_key_auth(request: Request):
     try:
         api_key = await oauth2_scheme(request=request)
         route = request.url.path
-
         if api_key == master_key:
             return
 
@@ -246,7 +245,6 @@ def run_ollama_serve():
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     global master_key
     config = {}
-    general_settings: dict = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
@@ -265,7 +263,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             os.environ[key] = value
 
     ## GENERAL SERVER SETTINGS (e.g. master key,..)
-    general_settings = config.get("general_settings", None)
+    general_settings = config.get("general_settings", {})
+    if general_settings is None:
+        general_settings = {}
     if general_settings:
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
@@ -292,7 +292,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             # print(f"litellm_model_name: {litellm_model_name}")
             if "ollama" in litellm_model_name:
                 run_ollama_serve()
-
+    print(f"returned general settings: {general_settings}")
     return router, model_list, general_settings
 
 async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
@@ -584,26 +584,19 @@ async def chat_completion(request: Request, model: Optional[str] = None):
 
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
-async def embeddings(request: Request, model: str):
+async def embeddings(request: Request):
     try:
-        body = await request.body()
-        body_str = body.decode()
-        try:
-            data = ast.literal_eval(body_str)
-        except:
-            data = json.loads(body_str)
-
+        data = await request.json()
+        print(f"data: {data}")
         data["model"] = (
             general_settings.get("embedding_model", None) # server default
             or user_model # model name passed via cli args
-            or model # for azure http calls
             or data["model"] # default passed in http request
         )
         if user_model:
             data["model"] = user_model
 
-        ## ROUTE TO CORRECT ENDPOINT ## 
+        ## ROUTE TO CORRECT ENDPOINT ##
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
             response = await llm_router.aembedding(**data)
@@ -611,6 +604,7 @@ async def embeddings(request: Request, model: str):
             response = litellm.aembedding(**data)
         return response
     except Exception as e:
+        traceback.print_exc()
         raise e
     except Exception as e:
         pass
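
A note on the server default wired up in load_router_config: the "general_settings" block of the proxy config now feeds both auth (master_key) and the embedding default. A minimal sketch of that block, for reference while reviewing — the key names come from this patch; the file name and values are illustrative:

    # config.yaml -- illustrative values, not part of this patch
    general_settings:
      master_key: sk-1234                       # bearer token checked by user_api_key_auth
      embedding_model: text-embedding-ada-002   # server-side default for /v1/embeddings

As written, the handler resolves the model as general_settings["embedding_model"], then user_model (the CLI model), then the request body's "model"; the trailing "if user_model:" check means a CLI-supplied model ultimately wins.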
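
To smoke-test the new route against a locally running proxy, something like the following should work (a sketch only: the port, key, model name, and input are assumptions about a local setup, not values fixed by this patch):

    import requests

    # POST to the new /v1/embeddings route; the Authorization header must
    # match the configured master_key (or a key generated from it).
    resp = requests.post(
        "http://0.0.0.0:8000/v1/embeddings",
        headers={"Authorization": "Bearer sk-1234"},
        json={"model": "text-embedding-ada-002", "input": ["hello world"]},
    )
    print(resp.status_code, resp.json())

Both /v1/embeddings and /embeddings are registered, so either path should behave identically.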