forked from phoenix/litellm-mirror

feat(proxy_server): add /v1/embeddings endpoint

commit 448ec0a571
parent 40dd38508f

1 changed file with 9 additions and 15 deletions
@@ -165,7 +165,6 @@ async def user_api_key_auth(request: Request):
     try:
         api_key = await oauth2_scheme(request=request)
         route = request.url.path
-
         if api_key == master_key:
             return
 
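For orientation, user_api_key_auth is the FastAPI dependency that guards every route touched below; this hunk only shows the short-circuit for the master key. A minimal sketch of the same pattern, with the rejection branch as an assumption since it sits outside the hunk:

from fastapi import HTTPException, Request
from fastapi.security import OAuth2PasswordBearer

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
master_key = "sk-1234"  # assumed value; the real proxy sets it from general_settings

async def user_api_key_auth(request: Request):
    api_key = await oauth2_scheme(request=request)
    # Requests signed with the master key are always let through.
    if api_key == master_key:
        return
    # Hypothetical fallback: handling of non-master keys is outside this hunk.
    raise HTTPException(status_code=401, detail="invalid user key")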
|
@@ -246,7 +245,6 @@ def run_ollama_serve():
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     global master_key
     config = {}
-    general_settings: dict = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
|
@@ -265,7 +263,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             os.environ[key] = value
 
     ## GENERAL SERVER SETTINGS (e.g. master key,..)
-    general_settings = config.get("general_settings", None)
+    general_settings = config.get("general_settings", {})
+    if general_settings is None:
+        general_settings = {}
     if general_settings:
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
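A likely motivation for the added None check, not spelled out in the commit message: when config.yaml lists the general_settings key with no value, YAML parses it as None, and dict.get's default only applies when the key is absent entirely. A small sketch, assuming PyYAML (which the proxy uses to read config.yaml):

import yaml

# Key present but empty: the value parses as None, so the .get default
# of {} never applies and the explicit None check is what catches it.
config = yaml.safe_load("general_settings:\n")
general_settings = config.get("general_settings", {})
print(general_settings)  # None, not {}
if general_settings is None:
    general_settings = {}
print(general_settings)  # {}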
|
@@ -292,7 +292,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         # print(f"litellm_model_name: {litellm_model_name}")
         if "ollama" in litellm_model_name:
             run_ollama_serve()
-
+    print(f"returned general settings: {general_settings}")
     return router, model_list, general_settings
 
 async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
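run_ollama_serve itself is outside the diff; one plausible shape for it, assuming the standard ollama CLI, is to start the local server in the background so ollama/* models resolve:

import subprocess

def run_ollama_serve():
    # Hypothetical sketch: launch `ollama serve` detached; the real helper
    # in proxy_server.py may differ in flags and error handling.
    subprocess.Popen(
        ["ollama", "serve"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )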
|
@@ -584,20 +584,13 @@ async def chat_completion(request: Request, model: Optional[str] = None):
 
 @router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-async def embeddings(request: Request):
+@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
+async def embeddings(request: Request, model: str):
     try:
-        body = await request.body()
-        body_str = body.decode()
-        try:
-            data = ast.literal_eval(body_str)
-        except:
-            data = json.loads(body_str)
-
+        data = await request.json()
+        print(f"data: {data}")
         data["model"] = (
             general_settings.get("embedding_model", None) # server default
             or user_model # model name passed via cli args
+            or model # for azure http calls
             or data["model"] # default passed in http request
         )
         if user_model:
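A hedged client-side sketch of what the three decorators buy; the base URL, key, and deployment name are illustrative assumptions, not values from the commit. Note that the sibling chat_completion endpoint declares model: Optional[str] = None; with the bare model: str used here, FastAPI treats model as a required query parameter on the routes whose path lacks {model}.

import requests

BASE = "http://0.0.0.0:8000"                   # assumed local proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # assumed master key

# Azure-style: {model:path} binds the handler's `model` argument, so the
# body may omit "model" (`or model` short-circuits before data["model"]).
r = requests.post(
    f"{BASE}/openai/deployments/my-embedding-deployment/embeddings",
    headers=HEADERS,
    json={"input": ["hello world"]},
)
print(r.json())

# OpenAI-style: no {model} in the path, so `model` arrives as a query
# parameter; an empty value makes the fallback chain read the body's "model".
r = requests.post(
    f"{BASE}/v1/embeddings",
    headers=HEADERS,
    params={"model": ""},
    json={"model": "text-embedding-ada-002", "input": ["hello world"]},
)
print(r.json())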
|
@@ -611,6 +604,7 @@ async def embeddings(request: Request, model: str):
         response = litellm.aembedding(**data)
         return response
     except Exception as e:
+        traceback.print_exc()
         raise e
     except Exception as e:
         pass
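The added traceback.print_exc() is the usual log-then-reraise pattern: the full stack trace lands in the server log while the exception keeps propagating so the framework can still return an error response. In isolation:

import traceback

def risky():
    raise ValueError("boom")

try:
    risky()
except Exception as e:
    traceback.print_exc()  # full stack trace to stderr (the server log)
    raise e                # re-raise so callers still see the failure

The trailing duplicate except Exception clause in the context is unreachable: the first one already catches everything and re-raises past the try block.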