add hf embedding models

This commit is contained in:
ishaan-jaff 2023-09-29 11:57:12 -07:00
parent f04d50d119
commit 3fbad7dfa7
2 changed files with 100 additions and 16 deletions

View file

@ -1350,19 +1350,23 @@ def batch_completion_models_all_responses(*args, **kwargs):
def embedding(
model,
input=[],
api_key=None,
api_base=None,
# Optional params
azure=False,
force_timeout=60,
litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None,
caching=False,
api_key=None,
custom_llm_provider=None,
):
model, custom_llm_provider = get_llm_provider(model, custom_llm_provider)
try:
response = None
logging = litellm_logging_obj
logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
if azure == True:
if azure == True or custom_llm_provider == "azure":
# azure configs
openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
openai.api_base = get_secret("AZURE_API_BASE")
@ -1380,6 +1384,9 @@ def embedding(
)
## EMBEDDING CALL
response = openai.Embedding.create(input=input, engine=model)
## LOGGING
logging.post_call(input=input, api_key=openai.api_key, original_response=response)
elif model in litellm.open_ai_embedding_models:
openai.api_type = "openai"
openai.api_base = "https://api.openai.com/v1"
@ -1414,20 +1421,25 @@ def embedding(
model_response= EmbeddingResponse()
)
# elif custom_llm_provider == "huggingface":
# response = huggingface_restapi.embedding(
# model=model,
# input=input,
# encoding=encoding,
# api_key=cohere_key,
# logging_obj=logging,
# model_response= EmbeddingResponse()
# )
elif custom_llm_provider == "huggingface":
api_key = (
api_key
or litellm.huggingface_key
or get_secret("HUGGINGFACE_API_KEY")
or litellm.api_key
)
response = huggingface_restapi.embedding(
model=model,
input=input,
encoding=encoding,
api_key=api_key,
api_base=api_base,
logging_obj=logging,
model_response= EmbeddingResponse()
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
## LOGGING
logging.post_call(input=input, api_key=openai.api_key, original_response=response)
return response
except Exception as e:
## LOGGING