diff --git a/litellm/llms/huggingface_restapi.py b/litellm/llms/huggingface_restapi.py
index 7783068afc..4f631c7ab8 100644
--- a/litellm/llms/huggingface_restapi.py
+++ b/litellm/llms/huggingface_restapi.py
@@ -266,6 +266,78 @@ def completion(
         }
         return model_response
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+
+def embedding(
+    model: str,
+    input: list,
+    api_key: str,
+    api_base: str,
+    logging_obj=None,
+    model_response=None,
+    encoding=None,
+):
+    headers = validate_environment(api_key)
+    # print_verbose(f"{model}, {task}")
+    embed_url = ""
+    if "https" in model:
+        embed_url = model
+    elif api_base:
+        embed_url = api_base
+    elif "HF_API_BASE" in os.environ:
+        embed_url = os.getenv("HF_API_BASE", "")
+    elif "HUGGINGFACE_API_BASE" in os.environ:
+        embed_url = os.getenv("HUGGINGFACE_API_BASE", "")
+    else:
+        embed_url = f"https://api-inference.huggingface.co/models/{model}"
+
+    data = {
+        "inputs": input
+    }
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    response = requests.post(
+        embed_url, headers=headers, data=json.dumps(data)
+    )
+
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+
+    embeddings = response.json()
+
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {
+                "object": "embedding",
+                "index": idx,
+                "embedding": embedding[0][0]  # flatten list returned from hf
+            }
+        )
+    model_response["object"] = "list"
+    model_response["data"] = output_data
+    model_response["model"] = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    model_response["usage"] = {
+        "prompt_tokens": input_tokens,
+        "total_tokens": input_tokens,
+    }
+    return model_response
+
+
+
diff --git a/litellm/main.py b/litellm/main.py
index 70b78c54ff..6c07ff35d4 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1350,19 +1350,23 @@ def batch_completion_models_all_responses(*args, **kwargs):
 def embedding(
     model,
     input=[],
+    api_key=None,
+    api_base=None,
+    # Optional params
     azure=False,
     force_timeout=60,
     litellm_call_id=None,
     litellm_logging_obj=None,
     logger_fn=None,
     caching=False,
-    api_key=None,
+    custom_llm_provider=None,
 ):
+    model, custom_llm_provider = get_llm_provider(model, custom_llm_provider)
     try:
         response = None
         logging = litellm_logging_obj
         logging.update_environment_variables(model=model, user="", optional_params={}, litellm_params={"force_timeout": force_timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn})
-        if azure == True:
+        if azure == True or custom_llm_provider == "azure":
             # azure configs
             openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
             openai.api_base = get_secret("AZURE_API_BASE")
@@ -1380,6 +1384,9 @@ def embedding(
             )
             ## EMBEDDING CALL
             response = openai.Embedding.create(input=input, engine=model)
+
+            ## LOGGING
+            logging.post_call(input=input, api_key=openai.api_key, original_response=response)
         elif model in litellm.open_ai_embedding_models:
             openai.api_type = "openai"
             openai.api_base = "https://api.openai.com/v1"
@@ -1414,20 +1421,25 @@ def embedding(
             model_response= EmbeddingResponse()
         )
 
-        # elif custom_llm_provider == "huggingface":
-        #     response = huggingface_restapi.embedding(
-        #         model=model,
-        #         input=input,
-        #         encoding=encoding,
-        #         api_key=cohere_key,
-        #         logging_obj=logging,
-        #         model_response= EmbeddingResponse()
-        #     )
+        elif custom_llm_provider == "huggingface":
+            api_key = (
+                api_key
+                or litellm.huggingface_key
+                or get_secret("HUGGINGFACE_API_KEY")
+                or litellm.api_key
+            )
+            response = huggingface_restapi.embedding(
+                model=model,
+                input=input,
+                encoding=encoding,
+                api_key=api_key,
+                api_base=api_base,
+                logging_obj=logging,
+                model_response= EmbeddingResponse()
+            )
         else:
             args = locals()
             raise ValueError(f"No valid embedding model args passed in - {args}")
-        ## LOGGING
-        logging.post_call(input=input, api_key=openai.api_key, original_response=response)
         return response
     except Exception as e:
         ## LOGGING
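
A minimal usage sketch for the new Hugging Face embedding path. The model id and key are illustrative, and the "huggingface/" prefix assumes get_llm_provider routes on that prefix; routing depends on how it resolves the model string.

import litellm

# Illustrative model id and key; any HF Inference API embedding model
# should flow through the same path. Setting HUGGINGFACE_API_KEY in the
# environment works too, per the key-resolution chain in main.py above.
response = litellm.embedding(
    model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
    input=["good morning from litellm"],
    api_key="hf_xxx",
)

# huggingface_restapi.embedding builds an OpenAI-style response:
# {"object": "list", "data": [{"object": "embedding", "index": 0, ...}], ...}
print(response["data"][0]["embedding"])

Passing api_base instead would target a private Inference Endpoint, mirroring the HF_API_BASE / HUGGINGFACE_API_BASE fallbacks in the URL resolution.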