Add Hugging Face (HF) embedding models

This commit is contained in:
ishaan-jaff 2023-09-29 11:57:12 -07:00
parent 932447e63a
commit 0dbb5a10e9
2 changed files with 100 additions and 16 deletions

View file

@@ -266,6 +266,78 @@ def completion(
}
return model_response
def embedding():
    """Placeholder for embedding support: parse input, call the model, parse output."""
    return None
def embedding(
    model: str,
    input: list,
    api_key: str,
    api_base: str,
    logging_obj=None,
    model_response=None,
    encoding=None,
):
    """Embed a list of texts via the Hugging Face Inference API.

    Args:
        model: HF model id, or a full ``https`` URL pointing at an inference endpoint.
        input: list of strings to embed.
        api_key: Hugging Face API token (used by ``validate_environment`` for headers).
        api_base: optional explicit endpoint base URL; overrides the env-var fallbacks.
        logging_obj: LiteLLM logging object; ``pre_call``/``post_call`` hooks are invoked.
        model_response: response object populated in-place and returned.
        encoding: tokenizer with an ``encode`` method, used to count prompt tokens.

    Returns:
        ``model_response`` populated with OpenAI-style ``data`` entries and ``usage``.

    Raises:
        ValueError: if the HF API returns an error payload instead of embeddings.
    """
    headers = validate_environment(api_key)
    # Endpoint resolution precedence: URL embedded in `model` > api_base arg
    # > HF_API_BASE env > HUGGINGFACE_API_BASE env > public HF inference API.
    if "https" in model:
        embed_url = model
    elif api_base:
        embed_url = api_base
    elif "HF_API_BASE" in os.environ:
        embed_url = os.getenv("HF_API_BASE", "")
    elif "HUGGINGFACE_API_BASE" in os.environ:
        embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
    else:
        embed_url = f"https://api-inference.huggingface.co/models/{model}"

    data = {
        "inputs": input
    }
    ## LOGGING
    logging_obj.pre_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
    )
    ## COMPLETION CALL
    response = requests.post(
        embed_url, headers=headers, data=json.dumps(data)
    )
    ## LOGGING
    logging_obj.post_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
        original_response=response,
    )

    embeddings = response.json()
    # The HF API signals failure with a dict payload (e.g. {"error": "..."}).
    # Without this check, enumerate() below would iterate dict keys and emit
    # garbage instead of surfacing the actual problem.
    if isinstance(embeddings, dict) and "error" in embeddings:
        raise ValueError(
            f"Hugging Face embedding request failed: {embeddings['error']}"
        )

    output_data = [
        {
            "object": "embedding",
            "index": idx,
            # NOTE(review): assumes the HF response nests each input's vectors
            # as [[vector, ...]]; taking [0][0] keeps the first token's vector.
            # Confirm this flattening is correct for the models in use.
            "embedding": emb[0][0],
        }
        for idx, emb in enumerate(embeddings)
    ]

    model_response["object"] = "list"
    model_response["data"] = output_data
    model_response["model"] = model
    # Usage accounting: only prompt tokens exist for an embedding call.
    input_tokens = sum(len(encoding.encode(text)) for text in input)
    model_response["usage"] = {
        "prompt_tokens": input_tokens,
        "total_tokens": input_tokens,
    }
    return model_response

View file

@@ -1350,19 +1350,23 @@ def batch_completion_models_all_responses(*args, **kwargs):
def embedding(
model,
input=[],
api_key=None,
api_base=None,
# Optional params
azure=False,
force_timeout=60,
litellm_call_id=None,
litellm_logging_obj=None,
logger_fn=None,
caching=False,
api_key=None,
custom_llm_provider=None,
):
model, custom_llm_provider = get_llm_provider(model, custom_llm_provider)
try:
response = None
logging = litellm_logging_obj
logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
if azure == True:
if azure == True or custom_llm_provider == "azure":
# azure configs
openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
openai.api_base = get_secret("AZURE_API_BASE")
@@ -1380,6 +1384,9 @@ def embedding(
)
## EMBEDDING CALL
response = openai.Embedding.create(input=input, engine=model)
## LOGGING
logging.post_call(input=input, api_key=openai.api_key, original_response=response)
elif model in litellm.open_ai_embedding_models:
openai.api_type = "openai"
openai.api_base = "https://api.openai.com/v1"
@@ -1414,20 +1421,25 @@ def embedding(
model_response= EmbeddingResponse()
)
# elif custom_llm_provider == "huggingface":
# response = huggingface_restapi.embedding(
# model=model,
# input=input,
# encoding=encoding,
# api_key=cohere_key,
# logging_obj=logging,
# model_response= EmbeddingResponse()
# )
elif custom_llm_provider == "huggingface":
api_key = (
api_key
or litellm.huggingface_key
or get_secret("HUGGINGFACE_API_KEY")
or litellm.api_key
)
response = huggingface_restapi.embedding(
model=model,
input=input,
encoding=encoding,
api_key=api_key,
api_base=api_base,
logging_obj=logging,
model_response= EmbeddingResponse()
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")
## LOGGING
logging.post_call(input=input, api_key=openai.api_key, original_response=response)
return response
except Exception as e:
## LOGGING