feat(proxy_server): add /v1/embeddings endpoint

Krrish Dholakia 2023-11-22 14:03:20 -08:00
parent 40dd38508f
commit 448ec0a571

@@ -165,7 +165,6 @@ async def user_api_key_auth(request: Request):
     try:
         api_key = await oauth2_scheme(request=request)
         route = request.url.path
         if api_key == master_key:
             return
@@ -246,7 +245,6 @@ def run_ollama_serve():
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     global master_key
     config = {}
-    general_settings: dict = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
@@ -265,7 +263,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                 os.environ[key] = value
     ## GENERAL SERVER SETTINGS (e.g. master key,..)
-    general_settings = config.get("general_settings", None)
+    general_settings = config.get("general_settings", {})
+    if general_settings is None:
+        general_settings = {}
     if general_settings:
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
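Aside, not part of the diff: the None guard above matters because dict.get only falls back to its default when the key is missing entirely. A YAML config that declares general_settings: with no value parses to an explicit None, so config.get("general_settings", {}) still returns None. A minimal sketch (assuming PyYAML, which the proxy uses to read its config file):

    import yaml

    # Key present but empty -- YAML parses the value as None
    cfg = yaml.safe_load("general_settings:\n")

    settings = cfg.get("general_settings", {})
    print(settings)          # None -- the {} default was NOT used
    if settings is None:     # hence the explicit guard added in the diff
        settings = {}
    print(settings)          # {}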
@@ -292,7 +292,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             # print(f"litellm_model_name: {litellm_model_name}")
             if "ollama" in litellm_model_name:
                 run_ollama_serve()
+    print(f"returned general settings: {general_settings}")
     return router, model_list, general_settings

 async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
@@ -584,20 +584,13 @@ async def chat_completion(request: Request, model: Optional[str] = None):
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
-async def embeddings(request: Request, model: str):
+async def embeddings(request: Request):
     try:
-        body = await request.body()
-        body_str = body.decode()
-        try:
-            data = ast.literal_eval(body_str)
-        except:
-            data = json.loads(body_str)
+        data = await request.json()
+        print(f"data: {data}")
         data["model"] = (
             general_settings.get("embedding_model", None) # server default
             or user_model # model name passed via cli args
-            or model # for azure http calls
             or data["model"] # default passed in http request
         )
         if user_model:
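The or-chain above resolves which model the embedding call will use: a server-wide embedding_model from general_settings first, then a model passed via CLI args, then the model field from the client's request body. A standalone sketch of the same precedence (resolve_model is a hypothetical helper, not in the diff):

    def resolve_model(general_settings: dict, user_model, request_body: dict) -> str:
        # First truthy value wins: server default > CLI override > client request.
        return (
            general_settings.get("embedding_model", None)
            or user_model
            or request_body["model"]
        )

    # No server default and no CLI model -> the client's choice is used.
    assert resolve_model({}, None, {"model": "text-embedding-ada-002"}) == "text-embedding-ada-002"
    # A configured embedding_model overrides whatever the client asked for.
    assert resolve_model({"embedding_model": "azure/my-embeddings"}, None, {"model": "gpt-4"}) == "azure/my-embeddings"

One consequence worth noting: a server-configured embedding_model silently takes precedence over the model the client requested.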
@@ -611,6 +604,7 @@ async def embeddings(request: Request, model: str):
         response = litellm.aembedding(**data)
         return response
     except Exception as e:
+        traceback.print_exc()
         raise e
     except Exception as e:
         pass
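For reference, exercising the new route once the proxy is running might look like this (host, port, and key are hypothetical; adjust to your deployment):

    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/v1/embeddings",
        headers={"Authorization": "Bearer sk-1234"},
        json={"model": "text-embedding-ada-002", "input": ["hello world"]},
    )
    print(resp.json())

The same request against /embeddings hits the identical handler, since both decorators are stacked on the one function.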