litellm-mirror/litellm/llms/watsonx/embed/transformation.py
Krish Dholakia 71f659d26b Complete 'requests' library removal (#7350)
* refactor: initial commit moving watsonx_text to base_llm_http_handler + clarifying new provider directory structure

* refactor(watsonx/completion/handler.py): move to using base llm http handler

removes 'requests' library usage

* fix(watsonx_text/transformation.py): fix result transformation

migrates to transformation.py, for usage with base llm http handler

* fix(streaming_handler.py): migrate watsonx streaming to transformation.py

ensures streaming works with base llm http handler

* fix(streaming_handler.py): fix streaming linting errors and remove watsonx conditional logic

* fix(watsonx/): fix chat route post completion route refactor

* refactor(watsonx/embed): refactor watsonx to use base llm http handler for embedding calls as well

* refactor(base.py): remove requests library usage from litellm

* build(pyproject.toml): remove requests library usage

* fix: fix linting errors

* fix: fix linting errors

* fix(types/utils.py): fix validation errors for modelresponsestream

* fix(replicate/handler.py): fix linting errors

* fix(litellm_logging.py): handle modelresponsestream object

* fix(streaming_handler.py): fix modelresponsestream args

* fix: remove unused imports

* test: fix test

* fix: fix test

* test: fix test

* test: fix tests

* test: fix test

* test: fix patch target

* test: fix test
2024-12-22 07:21:25 -08:00

"""
Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
"""
from typing import Optional
import httpx
from litellm.llms.base_llm.embedding.transformation import (
BaseEmbeddingConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import EmbeddingResponse, Usage
from ..common_utils import IBMWatsonXMixin, WatsonXAIError, _get_api_params
class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
    def get_supported_openai_params(self, model: str) -> list:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return optional_params

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        watsonx_api_params = _get_api_params(params=optional_params)
        project_id = watsonx_api_params["project_id"]
        if not project_id:
            raise ValueError("project_id is required")
        return {
            "inputs": input,
            "model_id": model,
            "project_id": project_id,
            "parameters": optional_params,
        }
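
    # Illustrative sketch (not part of the original file): the request body this
    # method produces for a non-deployment model. The model name and project id
    # below are hypothetical; _get_api_params may also resolve the project id
    # from the environment rather than from optional_params.
    #
    #   IBMWatsonXEmbeddingConfig().transform_embedding_request(
    #       model="ibm/slate-30m-english-rtrvr",
    #       input=["hello world"],
    #       optional_params={"project_id": "my-project-id"},
    #       headers={},
    #   )
    #   -> {
    #          "inputs": ["hello world"],
    #          "model_id": "ibm/slate-30m-english-rtrvr",
    #          "project_id": "my-project-id",
    #          "parameters": {...},
    #      }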
    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
        if model.startswith("deployment/"):
            # deployment models are passed in as 'deployment/<deployment_id>'
            if optional_params.get("space_id") is None:
                raise WatsonXAIError(
                    status_code=401,
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = endpoint.format(deployment_id=deployment_id)
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url
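
    # Illustrative sketch (not part of the original file): URL construction,
    # assuming WatsonXAIEndpoint.EMBEDDINGS resolves to IBM's documented
    # "/ml/v1/text/embeddings" path and that _add_api_version_to_url appends a
    # "?version=..." query parameter. Region and version are hypothetical.
    #
    #   api_base="https://us-south.ml.cloud.ibm.com", api_version="2024-03-13"
    #   -> "https://us-south.ml.cloud.ibm.com/ml/v1/text/embeddings?version=2024-03-13"
    #
    # For "deployment/<deployment_id>" models, a space_id must be supplied as a
    # parameter (or via WX_SPACE_ID) before the URL is built.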
    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        logging_obj.post_call(
            original_response=raw_response.text,
        )
        json_resp = raw_response.json()

        if model_response is None:
            model_response = EmbeddingResponse(model=json_resp.get("model_id", None))

        results = json_resp.get("results", [])
        embedding_response = []
        for idx, result in enumerate(results):
            embedding_response.append(
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": result["embedding"],
                }
            )
        model_response.object = "list"
        model_response.data = embedding_response

        input_tokens = json_resp.get("input_token_count", 0)
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=input_tokens,
                completion_tokens=0,
                total_tokens=input_tokens,
            ),
        )
        return model_response
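
    # Illustrative sketch (not part of the original file): the watsonx response
    # shape this method consumes, inferred from the keys read above, and the
    # OpenAI-style EmbeddingResponse it yields. Values are hypothetical.
    #
    #   {
    #       "model_id": "ibm/slate-30m-english-rtrvr",
    #       "results": [{"embedding": [0.1, 0.2, ...]}],
    #       "input_token_count": 4,
    #   }
    #
    # becomes:
    #
    #   EmbeddingResponse(
    #       object="list",
    #       data=[{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, ...]}],
    #       usage=Usage(prompt_tokens=4, completion_tokens=0, total_tokens=4),
    #   )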