Mirror of https://github.com/BerriAI/litellm.git
feat(proxy_server): add /v1/embeddings endpoint
parent 40dd38508f
commit 448ec0a571
1 changed file with 9 additions and 15 deletions
```diff
@@ -165,7 +165,6 @@ async def user_api_key_auth(request: Request):
     try:
         api_key = await oauth2_scheme(request=request)
         route = request.url.path

         if api_key == master_key:
             return
```
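For orientation, `user_api_key_auth` is a FastAPI dependency: any request whose bearer token equals `master_key` is authorized immediately. A minimal client-side sketch of exercising that check; the URL, key, and model name are placeholder assumptions, not values from this commit:

```python
import requests

PROXY_URL = "http://localhost:8000"   # assumed local proxy address
MASTER_KEY = "sk-1234"                # assumed value of `master_key`

resp = requests.post(
    f"{PROXY_URL}/v1/embeddings",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},  # read by oauth2_scheme
    json={"model": "text-embedding-ada-002", "input": ["hello world"]},
)
print(resp.status_code, resp.json())
```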
```diff
@@ -246,7 +245,6 @@ def run_ollama_serve():
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     global master_key
     config = {}
     general_settings: dict = {}
     try:
         if os.path.exists(config_file_path):
             with open(config_file_path, 'r') as file:
```
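The hunk is truncated right where the file is parsed. A minimal sketch of the likely loading pattern, assuming the standard pyyaml `yaml.safe_load` API (the real import lives elsewhere in the file):

```python
import os
import yaml  # pyyaml; assumed here for illustration

config_file_path = "config.yaml"  # illustrative path
config = {}
if os.path.exists(config_file_path):
    with open(config_file_path, 'r') as file:
        # an empty file parses to None, so fall back to an empty dict
        config = yaml.safe_load(file) or {}
```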
```diff
@@ -265,7 +263,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
             os.environ[key] = value

     ## GENERAL SERVER SETTINGS (e.g. master key,..)
-    general_settings = config.get("general_settings", None)
+    general_settings = config.get("general_settings", {})
+    if general_settings is None:
+        general_settings = {}
     if general_settings:
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
```
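The added `is None` guard is needed because `dict.get` only applies its default when the key is missing; a key that is present with an empty YAML value comes back as `None`. A self-contained demonstration (assumes pyyaml):

```python
import yaml

# "general_settings:" with no value -> key exists, value is None
config = yaml.safe_load("general_settings:\n")
print(config)  # {'general_settings': None}

general_settings = config.get("general_settings", {})
print(general_settings)  # None -- the {} default was NOT used

if general_settings is None:  # hence the guard added in this commit
    general_settings = {}
```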
```diff
@@ -292,7 +292,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         # print(f"litellm_model_name: {litellm_model_name}")
         if "ollama" in litellm_model_name:
             run_ollama_serve()

     print(f"returned general settings: {general_settings}")
     return router, model_list, general_settings

 async def generate_key_helper_fn(duration_str: str, models: list, aliases: dict, config: dict):
```
```diff
@@ -584,26 +584,19 @@ async def chat_completion(request: Request, model: Optional[str] = None):

+@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)])
 @router.post("/embeddings", dependencies=[Depends(user_api_key_auth)])
-@router.post("/openai/deployments/{model:path}/embeddings", dependencies=[Depends(user_api_key_auth)]) # azure compatible
-async def embeddings(request: Request, model: str):
+async def embeddings(request: Request):
     try:
-        body = await request.body()
-        body_str = body.decode()
-        try:
-            data = ast.literal_eval(body_str)
-        except:
-            data = json.loads(body_str)
-
+        data = await request.json()
         print(f"data: {data}")
         data["model"] = (
             general_settings.get("embedding_model", None) # server default
             or user_model # model name passed via cli args
-            or model # for azure http calls
             or data["model"] # default passed in http request
         )
+        if user_model:
+            data["model"] = user_model

         ## ROUTE TO CORRECT ENDPOINT ##
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
             response = await llm_router.aembedding(**data)
```
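Note the effective precedence in the rewritten handler: the `or` chain prefers the server-wide `embedding_model`, then the CLI `user_model`, then the request body, but the trailing `if user_model` block then overrides everything with the CLI model. A standalone sketch of that resolution (the function name is mine, for illustration):

```python
def resolve_embedding_model(data: dict, general_settings: dict, user_model):
    """Mirrors the handler above: server default, then cli arg, then request body."""
    model = (
        general_settings.get("embedding_model", None)  # server default
        or user_model                                  # model passed via cli args
        or data["model"]                               # model named in the request
    )
    if user_model:  # final override: the cli model always wins
        model = user_model
    return model

# no server default, no cli model -> the request body's model is used
assert resolve_embedding_model({"model": "text-embedding-ada-002"}, {}, None) == "text-embedding-ada-002"
# a cli model overrides both the server default and the request body
assert resolve_embedding_model({"model": "x"}, {"embedding_model": "y"}, "z") == "z"
```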
```diff
@@ -611,6 +604,7 @@ async def embeddings(request: Request, model: str):
             response = litellm.aembedding(**data)
         return response
     except Exception as e:
         traceback.print_exc()
         raise e
     except Exception as e:
         pass
```
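The dispatch at the end of the handler is: if the requested model is one of the router-managed deployments, go through `llm_router.aembedding` (which adds load balancing and retries); otherwise fall through to a direct `litellm.aembedding` call. A minimal async sketch of that decision; note that `litellm.aembedding` is a coroutine, so the fallback branch needs an `await` as well:

```python
import litellm

async def route_embedding(data: dict, llm_router, llm_model_list):
    # names of the deployments the router knows how to load-balance across
    router_model_names = (
        [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
    )
    if llm_router is not None and data["model"] in router_model_names:
        return await llm_router.aembedding(**data)   # router-managed deployment
    return await litellm.aembedding(**data)          # plain passthrough call
```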