add hf embedding models

This commit is contained in:
ishaan-jaff 2023-09-29 11:57:12 -07:00
parent f04d50d119
commit 3fbad7dfa7
2 changed files with 100 additions and 16 deletions

View file

@@ -266,6 +266,78 @@ def completion(
} }
return model_response return model_response
def embedding(
    model: str,
    input: list,
    api_key: str,
    api_base: str,
    logging_obj=None,
    model_response=None,
    encoding=None,
):
    """Embed a list of texts via the Hugging Face Inference API.

    Args:
        model: HF model id, or a full endpoint URL (detected by scheme prefix).
        input: list of strings to embed.
        api_key: Hugging Face API token.
        api_base: optional explicit endpoint; overrides env vars.
        logging_obj: litellm logging object (pre_call/post_call hooks).
        model_response: EmbeddingResponse-like mapping populated in place.
        encoding: tokenizer with an ``encode`` method, used for usage counts.

    Returns:
        ``model_response`` populated with OpenAI-style embedding data.

    Raises:
        Exception: if the HF API returns an error payload instead of vectors.
    """
    headers = validate_environment(api_key)
    # Endpoint resolution precedence: explicit URL passed as `model`,
    # then api_base arg, then env vars, then the public inference API.
    # (startswith replaces the old `"https" in model` substring test,
    # which would also match a model name merely containing "https".)
    if model.startswith(("https://", "http://")):
        embed_url = model
    elif api_base:
        embed_url = api_base
    elif "HF_API_BASE" in os.environ:
        embed_url = os.getenv("HF_API_BASE", "")
    elif "HUGGINGFACE_API_BASE" in os.environ:
        embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
    else:
        embed_url = f"https://api-inference.huggingface.co/models/{model}"

    data = {"inputs": input}

    ## LOGGING
    logging_obj.pre_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
    )
    ## COMPLETION CALL
    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
    ## LOGGING
    logging_obj.post_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
        original_response=response,
    )

    embeddings = response.json()
    # HF signals failure with a JSON object, not a list; iterating a dict
    # here would silently enumerate its keys instead of embeddings.
    if isinstance(embeddings, dict) and "error" in embeddings:
        raise Exception(embeddings["error"])

    output_data = [
        {
            "object": "embedding",
            "index": idx,
            # NOTE(review): assumes the HF feature-extraction response nests
            # each vector as [[vector, ...]] — confirm against the endpoint.
            "embedding": emb[0][0],  # flatten list returned from hf
        }
        for idx, emb in enumerate(embeddings)
    ]
    model_response["object"] = "list"
    model_response["data"] = output_data
    model_response["model"] = model

    # Usage: HF does not report token counts, so approximate with the
    # caller-supplied tokenizer.
    input_tokens = sum(len(encoding.encode(text)) for text in input)
    model_response["usage"] = {
        "prompt_tokens": input_tokens,
        "total_tokens": input_tokens,
    }
    return model_response

View file

@@ -1350,19 +1350,23 @@ def batch_completion_models_all_responses(*args, **kwargs):
def embedding( def embedding(
model, model,
input=[], input=[],
api_key=None,
api_base=None,
# Optional params
azure=False, azure=False,
force_timeout=60, force_timeout=60,
litellm_call_id=None, litellm_call_id=None,
litellm_logging_obj=None, litellm_logging_obj=None,
logger_fn=None, logger_fn=None,
caching=False, caching=False,
api_key=None, custom_llm_provider=None,
): ):
model, custom_llm_provider = get_llm_provider(model, custom_llm_provider)
try: try:
response = None response = None
logging = litellm_logging_obj logging = litellm_logging_obj
logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn}) logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
if azure == True: if azure == True or custom_llm_provider == "azure":
# azure configs # azure configs
openai.api_type = get_secret("AZURE_API_TYPE") or "azure" openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
openai.api_base = get_secret("AZURE_API_BASE") openai.api_base = get_secret("AZURE_API_BASE")
@@ -1380,6 +1384,9 @@ def embedding(
) )
## EMBEDDING CALL ## EMBEDDING CALL
response = openai.Embedding.create(input=input, engine=model) response = openai.Embedding.create(input=input, engine=model)
## LOGGING
logging.post_call(input=input, api_key=openai.api_key, original_response=response)
elif model in litellm.open_ai_embedding_models: elif model in litellm.open_ai_embedding_models:
openai.api_type = "openai" openai.api_type = "openai"
openai.api_base = "https://api.openai.com/v1" openai.api_base = "https://api.openai.com/v1"
@@ -1414,20 +1421,25 @@ def embedding(
model_response= EmbeddingResponse() model_response= EmbeddingResponse()
) )
# elif custom_llm_provider == "huggingface": elif custom_llm_provider == "huggingface":
# response = huggingface_restapi.embedding( api_key = (
# model=model, api_key
# input=input, or litellm.huggingface_key
# encoding=encoding, or get_secret("HUGGINGFACE_API_KEY")
# api_key=cohere_key, or litellm.api_key
# logging_obj=logging, )
# model_response= EmbeddingResponse() response = huggingface_restapi.embedding(
# ) model=model,
input=input,
encoding=encoding,
api_key=api_key,
api_base=api_base,
logging_obj=logging,
model_response= EmbeddingResponse()
)
else: else:
args = locals() args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}") raise ValueError(f"No valid embedding model args passed in - {args}")
## LOGGING
logging.post_call(input=input, api_key=openai.api_key, original_response=response)
return response return response
except Exception as e: except Exception as e:
## LOGGING ## LOGGING