mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

add hf embedding models

This commit is contained in: parent f04d50d119, commit 3fbad7dfa7

2 changed files with 100 additions and 16 deletions
litellm/llms/huggingface_restapi.py

@@ -266,6 +266,78 @@ def completion(
         }
     return model_response
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+def embedding(
+    model: str,
+    input: list,
+    api_key: str,
+    api_base: str,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    headers = validate_environment(api_key)
+    # print_verbose(f"{model}, {task}")
+    embed_url = ""
+    if "https" in model:
+        embed_url = model
+    elif api_base:
+        embed_url = api_base
+    elif "HF_API_BASE" in os.environ:
+        embed_url = os.getenv("HF_API_BASE", "")
+    elif "HUGGINGFACE_API_BASE" in os.environ:
+        embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
+    else:
+        embed_url = f"https://api-inference.huggingface.co/models/{model}"
+
+    data = {
+        "inputs": input
+    }
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    response = requests.post(
+        embed_url, headers=headers, data=json.dumps(data)
+    )
+
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+
+    embeddings = response.json()
+
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding[0][0] # flatten list returned from hf
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = output_data
+    model_response["model"] = model
+    input_tokens = 0
+    for text in input:
+        input_tokens+=len(encoding.encode(text))
+
+    model_response["usage"] = {
+        "prompt_tokens": input_tokens,
+        "total_tokens": input_tokens,
+    }
+    return model_response
+
+
+
+
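The embedding[0][0] indexing above assumes the Inference API's feature-extraction output is one nested list per input string, roughly [1][num_tokens][hidden_size] for a BERT-style model, so [0][0] keeps the hidden state of the first (CLS-position) token rather than a pooled sentence vector. A minimal standalone sketch of the same request/parse cycle, assuming that response shape; hf_embed_sketch and the model name are illustrative, not part of the commit:

import json

import requests


def hf_embed_sketch(texts, token, model="sentence-transformers/all-MiniLM-L6-v2"):
    # Same default endpoint as the handler above.
    url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {token}"}
    response = requests.post(url, headers=headers, data=json.dumps({"inputs": texts}))
    embeddings = response.json()
    # Mirrors embedding[0][0]: keep the first token's vector per input.
    return [e[0][0] for e in embeddings]

Models whose endpoint already returns one pooled vector per input would not need the nested indexing, which is why the flatten comment is HF-specific.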
litellm/main.py

@@ -1350,19 +1350,23 @@ def batch_completion_models_all_responses(*args, **kwargs):
 def embedding(
     model,
     input=[],
+    api_key=None,
+    api_base=None,
+    # Optional params
     azure=False,
     force_timeout=60,
     litellm_call_id=None,
     litellm_logging_obj=None,
     logger_fn=None,
     caching=False,
-    api_key=None,
+    custom_llm_provider=None,
 ):
+    model, custom_llm_provider = get_llm_provider(model, custom_llm_provider)
     try:
         response = None
         logging = litellm_logging_obj
         logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
-        if azure == True:
+        if azure == True or custom_llm_provider == "azure":
             # azure configs
             openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
             openai.api_base = get_secret("AZURE_API_BASE")
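The new get_llm_provider call is what lets a caller select a backend with a prefixed model string instead of passing custom_llm_provider explicitly. A rough stand-in for the behavior the embedding path relies on; resolve_provider is illustrative, not litellm's actual implementation:

def resolve_provider(model, custom_llm_provider=None):
    # An explicit provider wins; otherwise peel a known "provider/" prefix.
    known_providers = ["huggingface", "azure", "openai", "cohere"]
    if custom_llm_provider:
        return model, custom_llm_provider
    if "/" in model:
        prefix, rest = model.split("/", 1)
        if prefix in known_providers:
            return rest, prefix
    return model, None

# resolve_provider("huggingface/bigcode/starcoder")
# -> ("bigcode/starcoder", "huggingface")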
@@ -1380,6 +1384,9 @@ def embedding(
             )
             ## EMBEDDING CALL
             response = openai.Embedding.create(input=input, engine=model)
+
+            ## LOGGING
+            logging.post_call(input=input, api_key=openai.api_key, original_response=response)
         elif model in litellm.open_ai_embedding_models:
             openai.api_type = "openai"
             openai.api_base = "https://api.openai.com/v1"
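The post_call added here mirrors the pre_call/post_call pair in the Hugging Face handler, so each embedding request logs its payload on the way out and the raw response on the way back. A minimal stand-in for the contract these calls assume; litellm's real Logging class carries far more state:

class MinimalLogger:
    # Only the two hooks the embedding paths above invoke.
    def pre_call(self, input, api_key, additional_args=None):
        print(f"pre_call: input={input!r}, args={additional_args}")

    def post_call(self, input, api_key, original_response, additional_args=None):
        status = getattr(original_response, "status_code", "n/a")
        print(f"post_call: status={status}")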
@@ -1414,20 +1421,25 @@ def embedding(
                 model_response= EmbeddingResponse()
 
             )
-        # elif custom_llm_provider == "huggingface":
-        #     response = huggingface_restapi.embedding(
-        #         model=model,
-        #         input=input,
-        #         encoding=encoding,
-        #         api_key=cohere_key,
-        #         logging_obj=logging,
-        #         model_response= EmbeddingResponse()
-        #     )
+        elif custom_llm_provider == "huggingface":
+            api_key = (
+                api_key
+                or litellm.huggingface_key
+                or get_secret("HUGGINGFACE_API_KEY")
+                or litellm.api_key
+            )
+            response = huggingface_restapi.embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                api_key=api_key,
+                api_base=api_base,
+                logging_obj=logging,
+                model_response= EmbeddingResponse()
+            )
         else:
             args = locals()
             raise ValueError(f"No valid embedding model args passed in - {args}")
-        ## LOGGING
-        logging.post_call(input=input, api_key=openai.api_key, original_response=response)
         return response
     except Exception as e:
         ## LOGGING
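Taken together, the two files make a Hugging Face embedding reachable from the public API in one call. A usage sketch, assuming the provider-prefix convention shown above; the model name and key are placeholders:

import os

import litellm

os.environ["HUGGINGFACE_API_KEY"] = "hf_..."  # or pass api_key= directly

response = litellm.embedding(
    model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
    input=["good morning from litellm"],
)
print(response["data"][0]["embedding"][:5])  # first few values of the vector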