Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 19:24:27 +00:00)
Complete 'requests' library removal (#7350)
* refactor: initial commit moving watsonx_text to base_llm_http_handler + clarifying new provider directory structure
* refactor(watsonx/completion/handler.py): move to using base llm http handler removes 'requests' library usage
* fix(watsonx_text/transformation.py): fix result transformation migrates to transformation.py, for usage with base llm http handler
* fix(streaming_handler.py): migrate watsonx streaming to transformation.py ensures streaming works with base llm http handler
* fix(streaming_handler.py): fix streaming linting errors and remove watsonx conditional logic
* fix(watsonx/): fix chat route post completion route refactor
* refactor(watsonx/embed): refactor watsonx to use base llm http handler for embedding calls as well
* refactor(base.py): remove requests library usage from litellm
* build(pyproject.toml): remove requests library usage
* fix: fix linting errors
* fix: fix linting errors
* fix(types/utils.py): fix validation errors for modelresponsestream
* fix(replicate/handler.py): fix linting errors
* fix(litellm_logging.py): handle modelresponsestream object
* fix(streaming_handler.py): fix modelresponsestream args
* fix: remove unused imports
* test: fix test
* fix: fix test
* test: fix test
* test: fix tests
* test: fix test
* test: fix patch target
* test: fix test
Parent: 4b0bef1823
Commit: 71f659d26b
39 changed files with 2147 additions and 2279 deletions
.gitignore (vendored): 1 change
@@ -67,3 +67,4 @@ litellm/tests/langfuse.log
 litellm/proxy/google-cloud-sdk/*
 tests/llm_translation/log.txt
 venv/
+tests/local_testing/log.txt
@@ -5,16 +5,16 @@ When adding a new provider, you need to create a directory for the provider that
 ```
 litellm/llms/
 └── provider_name/
-    ├── completion/
+    ├── completion/ # use when endpoint is equivalent to openai's `/v1/completions`
     │   ├── handler.py
     │   └── transformation.py
-    ├── chat/
+    ├── chat/ # use when endpoint is equivalent to openai's `/v1/chat/completions`
     │   ├── handler.py
     │   └── transformation.py
-    ├── embed/
+    ├── embed/ # use when endpoint is equivalent to openai's `/v1/embeddings`
     │   ├── handler.py
     │   └── transformation.py
-    └── rerank/
+    └── rerank/ # use when endpoint is equivalent to cohere's `/rerank` endpoint.
         ├── handler.py
         └── transformation.py
 ```
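The directory convention documented above is what the watsonx refactor in this commit follows. As a rough sketch only (the provider name, class name, and comments below are hypothetical and not part of this commit; only `BaseConfig` and its import path appear in this diff), a new provider's `chat/transformation.py` declares how requests map onto the provider's API, while the shared HTTP handler performs the actual call:

```python
# Hypothetical skeleton for litellm/llms/my_provider/chat/transformation.py -- illustrative only.
from litellm.llms.base_llm.chat.transformation import BaseConfig


class MyProviderChatConfig(BaseConfig):
    """Declarative piece of a provider: maps litellm requests onto the provider's
    /v1/chat/completions-equivalent endpoint. The matching chat/handler.py (or the
    shared base_llm_http_handler) performs the actual HTTP call."""

    # Expected overrides (see the BaseConfig / BaseLLMHTTPHandler hunks later in this diff):
    #   validate_environment(...) -> headers with auth
    #   get_complete_url(...)     -> full request URL
    #   transform_request(...)    -> provider-specific request body
    #   plus a response transformation back into litellm's ModelResponse
```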
@@ -991,6 +991,7 @@ from .utils import (
 get_api_base,
 get_first_chars_messages,
 ModelResponse,
+ModelResponseStream,
 EmbeddingResponse,
 ImageResponse,
 TranscriptionResponse,

@@ -1157,6 +1158,7 @@ from .llms.perplexity.chat.transformation import PerplexityChatConfig
 from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
 from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
 from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
+from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
 from .main import * # type: ignore
 from .integrations import *
 from .exceptions import (
@@ -43,6 +43,7 @@ from litellm.types.utils import (
 ImageResponse,
 LiteLLMLoggingBaseClass,
 ModelResponse,
+ModelResponseStream,
 StandardCallbackDynamicParams,
 StandardLoggingAdditionalHeaders,
 StandardLoggingHiddenParams,

@@ -741,6 +742,7 @@ class Logging(LiteLLMLoggingBaseClass):
 self,
 result: Union[
 ModelResponse,
+ModelResponseStream,
 EmbeddingResponse,
 ImageResponse,
 TranscriptionResponse,

@@ -848,6 +850,7 @@ class Logging(LiteLLMLoggingBaseClass):
 ): # handle streaming separately
 if (
 isinstance(result, ModelResponse)
+or isinstance(result, ModelResponseStream)
 or isinstance(result, EmbeddingResponse)
 or isinstance(result, ImageResponse)
 or isinstance(result, TranscriptionResponse)

@@ -955,6 +958,7 @@ class Logging(LiteLLMLoggingBaseClass):
 if self.stream and (
 isinstance(result, litellm.ModelResponse)
 or isinstance(result, TextCompletionResponse)
+or isinstance(result, ModelResponseStream)
 ):
 complete_streaming_response: Optional[
 Union[ModelResponse, TextCompletionResponse]

@@ -966,9 +970,6 @@ class Logging(LiteLLMLoggingBaseClass):
 streaming_chunks=self.sync_streaming_chunks,
 is_async=False,
 )
-_caching_complete_streaming_response: Optional[
-Union[ModelResponse, TextCompletionResponse]
-] = None
 if complete_streaming_response is not None:
 verbose_logger.debug(
 "Logging Details LiteLLM-Success Call streaming complete"

@@ -976,9 +977,6 @@ class Logging(LiteLLMLoggingBaseClass):
 self.model_call_details["complete_streaming_response"] = (
 complete_streaming_response
 )
-_caching_complete_streaming_response = copy.deepcopy(
-complete_streaming_response
-)
 self.model_call_details["response_cost"] = (
 self._response_cost_calculator(result=complete_streaming_response)
 )

@@ -1474,6 +1472,7 @@ class Logging(LiteLLMLoggingBaseClass):
 ] = None
 if self.stream is True and (
 isinstance(result, litellm.ModelResponse)
+or isinstance(result, litellm.ModelResponseStream)
 or isinstance(result, TextCompletionResponse)
 ):
 complete_streaming_response: Optional[
@@ -2,7 +2,11 @@ from datetime import datetime
 from typing import TYPE_CHECKING, Any, List, Optional, Union

 from litellm._logging import verbose_logger
-from litellm.types.utils import ModelResponse, TextCompletionResponse
+from litellm.types.utils import (
+ModelResponse,
+ModelResponseStream,
+TextCompletionResponse,
+)

 if TYPE_CHECKING:
 from litellm import ModelResponse as _ModelResponse

@@ -38,7 +42,7 @@ def convert_litellm_response_object_to_str(


 def _assemble_complete_response_from_streaming_chunks(
-result: Union[ModelResponse, TextCompletionResponse],
+result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream],
 start_time: datetime,
 end_time: datetime,
 request_kwargs: dict,
@@ -5,7 +5,7 @@ import time
 import traceback
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, cast

 import httpx
 from pydantic import BaseModel

@@ -611,44 +611,6 @@ class CustomStreamWrapper:
 except Exception as e:
 raise e

-def handle_watsonx_stream(self, chunk):
-try:
-if isinstance(chunk, dict):
-parsed_response = chunk
-elif isinstance(chunk, (str, bytes)):
-if isinstance(chunk, bytes):
-chunk = chunk.decode("utf-8")
-if "generated_text" in chunk:
-response = chunk.replace("data: ", "").strip()
-parsed_response = json.loads(response)
-else:
-return {
-"text": "",
-"is_finished": False,
-"prompt_tokens": 0,
-"completion_tokens": 0,
-}
-else:
-print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
-raise ValueError(
-f"Unable to parse response. Original response: {chunk}"
-)
-results = parsed_response.get("results", [])
-if len(results) > 0:
-text = results[0].get("generated_text", "")
-finish_reason = results[0].get("stop_reason")
-is_finished = finish_reason != "not_finished"
-return {
-"text": text,
-"is_finished": is_finished,
-"finish_reason": finish_reason,
-"prompt_tokens": results[0].get("input_token_count", 0),
-"completion_tokens": results[0].get("generated_token_count", 0),
-}
-return {"text": "", "is_finished": False}
-except Exception as e:
-raise e
-
 def handle_triton_stream(self, chunk):
 try:
 if isinstance(chunk, dict):
@@ -702,9 +664,18 @@ class CustomStreamWrapper:
 # pop model keyword
 chunk.pop("model", None)

-model_response = ModelResponse(
-stream=True, model=_model, stream_options=self.stream_options, **chunk
-)
+chunk_dict = {}
+for key, value in chunk.items():
+if key != "stream":
+chunk_dict[key] = value
+
+args = {
+"model": _model,
+"stream_options": self.stream_options,
+**chunk_dict,
+}
+
+model_response = ModelResponseStream(**args)
 if self.response_id is not None:
 model_response.id = self.response_id
 else:

@@ -742,9 +713,9 @@ class CustomStreamWrapper:

 def return_processed_chunk_logic( # noqa
 self,
-completion_obj: dict,
+completion_obj: Dict[str, Any],
 model_response: ModelResponseStream,
-response_obj: dict,
+response_obj: Dict[str, Any],
 ):

 print_verbose(

@@ -887,11 +858,11 @@ class CustomStreamWrapper:

 def chunk_creator(self, chunk): # type: ignore # noqa: PLR0915
 model_response = self.model_response_creator()
-response_obj: dict = {}
+response_obj: Dict[str, Any] = {}

 try:
 # return this for all models
-completion_obj = {"content": ""}
+completion_obj: Dict[str, Any] = {"content": ""}
 from litellm.types.utils import GenericStreamingChunk as GChunk

 if (

@@ -1089,11 +1060,6 @@ class CustomStreamWrapper:
 print_verbose(f"completion obj content: {completion_obj['content']}")
 if response_obj["is_finished"]:
 self.received_finish_reason = response_obj["finish_reason"]
-elif self.custom_llm_provider == "watsonx":
-response_obj = self.handle_watsonx_stream(chunk)
-completion_obj["content"] = response_obj["text"]
-if response_obj["is_finished"]:
-self.received_finish_reason = response_obj["finish_reason"]
 elif self.custom_llm_provider == "triton":
 response_obj = self.handle_triton_stream(chunk)
 completion_obj["content"] = response_obj["text"]

@@ -1158,7 +1124,7 @@ class CustomStreamWrapper:
 self.received_finish_reason = response_obj["finish_reason"]
 else: # openai / azure chat model
 if self.custom_llm_provider == "azure":
-if hasattr(chunk, "model"):
+if isinstance(chunk, BaseModel) and hasattr(chunk, "model"):
 # for azure, we need to pass the model from the orignal chunk
 self.model = chunk.model
 response_obj = self.handle_openai_chat_completion_chunk(chunk)
@@ -1190,7 +1156,10 @@ class CustomStreamWrapper:

 if response_obj["usage"] is not None:
 if isinstance(response_obj["usage"], dict):
-model_response.usage = litellm.Usage(
+setattr(
+model_response,
+"usage",
+litellm.Usage(
 prompt_tokens=response_obj["usage"].get(
 "prompt_tokens", None
 )

@@ -1199,12 +1168,17 @@ class CustomStreamWrapper:
 "completion_tokens", None
 )
 or None,
-total_tokens=response_obj["usage"].get("total_tokens", None)
+total_tokens=response_obj["usage"].get(
+"total_tokens", None
+)
 or None,
+),
 )
 elif isinstance(response_obj["usage"], BaseModel):
-model_response.usage = litellm.Usage(
-**response_obj["usage"].model_dump()
+setattr(
+model_response,
+"usage",
+litellm.Usage(**response_obj["usage"].model_dump()),
 )

 model_response.model = self.model

@@ -1337,7 +1311,7 @@ class CustomStreamWrapper:
 raise StopIteration
 except Exception as e:
 traceback.format_exc()
-e.message = str(e)
+setattr(e, "message", str(e))
 raise exception_type(
 model=self.model,
 custom_llm_provider=self.custom_llm_provider,

@@ -1434,7 +1408,9 @@ class CustomStreamWrapper:
 print_verbose(
 f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
 )
-response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
+response: Optional[ModelResponseStream] = self.chunk_creator(
+chunk=chunk
+)
 print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")

 if response is None:

@@ -1597,7 +1573,7 @@ class CustomStreamWrapper:
 # __anext__ also calls async_success_handler, which does logging
 print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")

-processed_chunk: Optional[ModelResponse] = self.chunk_creator(
+processed_chunk: Optional[ModelResponseStream] = self.chunk_creator(
 chunk=chunk
 )
 print_verbose(

@@ -1624,7 +1600,7 @@ class CustomStreamWrapper:
 if self.logging_obj._llm_caching_handler is not None:
 asyncio.create_task(
 self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
-processed_chunk=processed_chunk,
+processed_chunk=cast(ModelResponse, processed_chunk),
 )
 )

@@ -1663,8 +1639,8 @@ class CustomStreamWrapper:
 chunk = next(self.completion_stream)
 if chunk is not None and chunk != b"":
 print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
-processed_chunk: Optional[ModelResponse] = self.chunk_creator(
-chunk=chunk
+processed_chunk: Optional[ModelResponseStream] = (
+self.chunk_creator(chunk=chunk)
 )
 print_verbose(
 f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
@@ -2,7 +2,6 @@
 from typing import Any, Optional, Union

 import httpx
-import requests

 import litellm
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

@@ -16,7 +15,7 @@ class BaseLLM:
 def process_response(
 self,
 model: str,
-response: Union[requests.Response, httpx.Response],
+response: httpx.Response,
 model_response: ModelResponse,
 stream: bool,
 logging_obj: Any,

@@ -35,7 +34,7 @@ class BaseLLM:
 def process_text_completion_response(
 self,
 model: str,
-response: Union[requests.Response, httpx.Response],
+response: httpx.Response,
 model_response: TextCompletionResponse,
 stream: bool,
 logging_obj: Any,
@@ -107,7 +107,13 @@ class BaseConfig(ABC):
 ) -> dict:
 pass

-def get_complete_url(self, api_base: str, model: str) -> str:
+def get_complete_url(
+self,
+api_base: str,
+model: str,
+optional_params: dict,
+stream: Optional[bool] = None,
+) -> str:
 """
 OPTIONAL

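For context on the hunk above: after this change, `get_complete_url` receives the request's `optional_params` and the `stream` flag instead of only `api_base` and `model`. A minimal hedged sketch of how a provider config might use the widened signature (the class, endpoint paths, and default version below are hypothetical; only the parameter list comes from this hunk):

```python
from typing import Optional


class ExampleProviderConfig:  # stand-in for a BaseConfig subclass
    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        # The extra arguments let a provider choose a streaming endpoint and read
        # request-scoped settings (e.g. an api_version) without extra plumbing.
        path = "/generate_stream" if stream else "/generate"
        version = optional_params.get("api_version", "2024-03-13")  # placeholder default
        return f"{api_base.rstrip('/')}{path}?version={version}"
```

The watsonx, Ollama, Replicate, Cloudflare, and Voyage configs later in this diff adopt the same parameter list.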
@@ -1,10 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional

 import httpx

 from litellm.llms.base_llm.chat.transformation import BaseConfig
-from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
 from litellm.types.utils import EmbeddingResponse, ModelResponse

 if TYPE_CHECKING:

@@ -16,12 +16,11 @@ else:


 class BaseEmbeddingConfig(BaseConfig, ABC):
-
 @abstractmethod
 def transform_embedding_request(
 self,
 model: str,
-input: Union[str, List[str], List[float], List[List[float]]],
+input: AllEmbeddingInputValues,
 optional_params: dict,
 headers: dict,
 ) -> dict:

@@ -34,14 +33,20 @@ class BaseEmbeddingConfig(BaseConfig, ABC):
 raw_response: httpx.Response,
 model_response: EmbeddingResponse,
 logging_obj: LiteLLMLoggingObj,
-api_key: Optional[str] = None,
-request_data: dict = {},
-optional_params: dict = {},
-litellm_params: dict = {},
+api_key: Optional[str],
+request_data: dict,
+optional_params: dict,
+litellm_params: dict,
 ) -> EmbeddingResponse:
 return model_response

-def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+def get_complete_url(
+self,
+api_base: Optional[str],
+model: str,
+optional_params: dict,
+stream: Optional[bool] = None,
+) -> str:
 """
 OPTIONAL

@@ -72,7 +72,13 @@ class CloudflareChatConfig(BaseConfig):
 }
 return headers

-def get_complete_url(self, api_base: str, model: str) -> str:
+def get_complete_url(
+self,
+api_base: str,
+model: str,
+optional_params: dict,
+stream: Optional[bool] = None,
+) -> str:
 return api_base + model

 def get_supported_openai_params(self, model: str) -> List[str]:
@@ -110,6 +110,8 @@ class BaseLLMHTTPHandler:
 api_base = provider_config.get_complete_url(
 api_base=api_base,
 model=model,
+optional_params=optional_params,
+stream=stream,
 )

 data = provider_config.transform_request(

@@ -402,6 +404,7 @@
 logging_obj: LiteLLMLoggingObj,
 api_base: Optional[str],
 optional_params: dict,
+litellm_params: dict,
 model_response: EmbeddingResponse,
 api_key: Optional[str] = None,
 client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,

@@ -424,6 +427,7 @@
 api_base = provider_config.get_complete_url(
 api_base=api_base,
 model=model,
+optional_params=optional_params,
 )

 data = provider_config.transform_embedding_request(

@@ -457,6 +461,8 @@
 api_key=api_key,
 timeout=timeout,
 client=client,
+optional_params=optional_params,
+litellm_params=litellm_params,
 )

 if client is None or not isinstance(client, HTTPHandler):

@@ -484,6 +490,8 @@
 logging_obj=logging_obj,
 api_key=api_key,
 request_data=data,
+optional_params=optional_params,
+litellm_params=litellm_params,
 )

 async def aembedding(

@@ -496,6 +504,8 @@
 provider_config: BaseEmbeddingConfig,
 model_response: EmbeddingResponse,
 logging_obj: LiteLLMLoggingObj,
+optional_params: dict,
+litellm_params: dict,
 api_key: Optional[str] = None,
 timeout: Optional[Union[float, httpx.Timeout]] = None,
 client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,

@@ -524,6 +534,8 @@
 logging_obj=logging_obj,
 api_key=api_key,
 request_data=request_data,
+optional_params=optional_params,
+litellm_params=litellm_params,
 )

 def rerank(
@@ -350,7 +350,13 @@ class OllamaConfig(BaseConfig):
 ) -> dict:
 return headers

-def get_complete_url(self, api_base: str, model: str) -> str:
+def get_complete_url(
+self,
+api_base: str,
+model: str,
+optional_params: dict,
+stream: Optional[bool] = None,
+) -> str:
 """
 OPTIONAL

@@ -168,7 +168,9 @@ def completion(
 time.time()
 ) # for pricing this must remain right before calling api

-prediction_url = replicate_config.get_complete_url(api_base, model)
+prediction_url = replicate_config.get_complete_url(
+api_base=api_base, model=model, optional_params=optional_params
+)

 ## COMPLETION CALL
 httpx_client = _get_httpx_client(

@@ -235,7 +237,9 @@ async def async_completion(
 headers: dict,
 ) -> Union[ModelResponse, CustomStreamWrapper]:

-prediction_url = replicate_config.get_complete_url(api_base=api_base, model=model)
+prediction_url = replicate_config.get_complete_url(
+api_base=api_base, model=model, optional_params=optional_params
+)
 async_handler = get_async_httpx_client(
 llm_provider=litellm.LlmProviders.REPLICATE,
 params={"timeout": 600.0},
@@ -136,7 +136,13 @@ class ReplicateConfig(BaseConfig):
 status_code=status_code, message=error_message, headers=headers
 )

-def get_complete_url(self, api_base: str, model: str) -> str:
+def get_complete_url(
+self,
+api_base: str,
+model: str,
+optional_params: dict,
+stream: Optional[bool] = None,
+) -> str:
 version_id = self.model_to_version_id(model)
 base_url = api_base
 if "deployments" in version_id:
@@ -7,6 +7,7 @@ from litellm.llms.base_llm.embedding.transformation import (
 BaseEmbeddingConfig,
 LiteLLMLoggingObj,
 )
+from litellm.types.llms.openai import AllEmbeddingInputValues
 from litellm.types.utils import EmbeddingResponse

 from ..common_utils import TritonError

@@ -48,7 +49,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
 def transform_embedding_request(
 self,
 model: str,
-input: Union[str, List[str], List[float], List[List[float]]],
+input: AllEmbeddingInputValues,
 optional_params: dict,
 headers: dict,
 ) -> dict:
@@ -6,7 +6,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
 from litellm.types.utils import EmbeddingResponse, Usage



@@ -38,7 +38,13 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
 def __init__(self) -> None:
 pass

-def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+def get_complete_url(
+self,
+api_base: Optional[str],
+model: str,
+optional_params: dict,
+stream: Optional[bool] = None,
+) -> str:
 if api_base:
 if not api_base.endswith("/embeddings"):
 api_base = f"{api_base}/embeddings"

@@ -90,7 +96,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
 def transform_embedding_request(
 self,
 model: str,
-input: Union[str, List[str], List[float], List[List[float]]],
+input: AllEmbeddingInputValues,
 optional_params: dict,
 headers: dict,
 ) -> dict:
@@ -3,57 +3,19 @@ from typing import Callable, Optional, Union
 import httpx

 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
 from litellm.types.utils import CustomStreamingDecoder, ModelResponse

 from ...openai_like.chat.handler import OpenAILikeChatHandler
-from ..common_utils import WatsonXAIError, _get_api_params
+from ..common_utils import _get_api_params
+from .transformation import IBMWatsonXChatConfig

+watsonx_chat_transformation = IBMWatsonXChatConfig()

+
 class WatsonXChatHandler(OpenAILikeChatHandler):
 def __init__(self, **kwargs):
 super().__init__(**kwargs)

-def _prepare_url(
-self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
-) -> str:
-if model.startswith("deployment/"):
-if api_params.get("space_id") is None:
-raise WatsonXAIError(
-status_code=401,
-message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-)
-deployment_id = "/".join(model.split("/")[1:])
-endpoint = (
-WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
-if stream is True
-else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
-)
-endpoint = endpoint.format(deployment_id=deployment_id)
-else:
-endpoint = (
-WatsonXAIEndpoint.CHAT_STREAM.value
-if stream is True
-else WatsonXAIEndpoint.CHAT.value
-)
-base_url = httpx.URL(api_params["url"])
-base_url = base_url.join(endpoint)
-full_url = str(
-base_url.copy_add_param(key="version", value=api_params["api_version"])
-)
-
-return full_url
-
-def _prepare_payload(
-self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
-) -> dict:
-payload: dict = {}
-if model.startswith("deployment/"):
-return payload
-payload["model_id"] = model
-payload["project_id"] = api_params["project_id"]
-return payload
-
 def completion(
 self,
 *,

@@ -70,32 +32,37 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
 optional_params: dict,
 acompletion=None,
 litellm_params=None,
-logger_fn=None,
 headers: Optional[dict] = None,
+logger_fn=None,
 timeout: Optional[Union[float, httpx.Timeout]] = None,
 client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 custom_endpoint: Optional[bool] = None,
 streaming_decoder: Optional[CustomStreamingDecoder] = None,
 fake_stream: bool = False,
 ):
-api_params = _get_api_params(optional_params, print_verbose=print_verbose)
+api_params = _get_api_params(params=optional_params)

-if headers is None:
-headers = {}
-headers.update(
-{
-"Authorization": f"Bearer {api_params['token']}",
-"Content-Type": "application/json",
-"Accept": "application/json",
-}
+## UPDATE HEADERS
+headers = watsonx_chat_transformation.validate_environment(
+headers=headers or {},
+model=model,
+messages=messages,
+optional_params=optional_params,
+api_key=api_key,
 )

-stream: Optional[bool] = optional_params.get("stream", False)
+## GET API URL
+api_base = watsonx_chat_transformation.get_complete_url(
+api_base=api_base,
+model=model,
+optional_params=optional_params,
+stream=optional_params.get("stream", False),
+)

-## get api url and payload
-api_base = self._prepare_url(model=model, api_params=api_params, stream=stream)
-watsonx_auth_payload = self._prepare_payload(
-model=model, api_params=api_params, stream=stream
+## UPDATE PAYLOAD (optional params)
+watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
+model=model,
+api_params=api_params,
 )
 optional_params.update(watsonx_auth_payload)

@@ -7,12 +7,14 @@ Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
 from typing import List, Optional, Tuple, Union

 from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams

 from ....utils import _remove_additional_properties, _remove_strict_from_schema
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+from ..common_utils import IBMWatsonXMixin, WatsonXAIError


-class IBMWatsonXChatConfig(OpenAIGPTConfig):
+class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):

 def get_supported_openai_params(self, model: str) -> List:
 return [

@@ -75,3 +77,47 @@ class IBMWatsonXChatConfig(OpenAIGPTConfig):
 api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
 ) # vllm does not require an api key
 return api_base, dynamic_api_key
+
+def get_complete_url(
+self,
+api_base: str,
+model: str,
+optional_params: dict,
+stream: Optional[bool] = None,
+) -> str:
+url = self._get_base_url(api_base=api_base)
+if model.startswith("deployment/"):
+# deployment models are passed in as 'deployment/<deployment_id>'
+if optional_params.get("space_id") is None:
+raise WatsonXAIError(
+status_code=401,
+message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+)
+deployment_id = "/".join(model.split("/")[1:])
+endpoint = (
+WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+if stream
+else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+)
+endpoint = endpoint.format(deployment_id=deployment_id)
+else:
+endpoint = (
+WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
+if stream
+else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
+)
+url = url.rstrip("/") + endpoint
+
+## add api version
+url = self._add_api_version_to_url(
+url=url, api_version=optional_params.pop("api_version", None)
+)
+return url
+
+def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
+payload: dict = {}
+if model.startswith("deployment/"):
+return payload
+payload["model_id"] = model
+payload["project_id"] = api_params["project_id"]
+return payload
@@ -1,13 +1,15 @@
-from typing import Callable, Dict, Optional, Union, cast
+from typing import Dict, List, Optional, Union, cast

 import httpx

 import litellm
 from litellm import verbose_logger
 from litellm.caching import InMemoryCache
+from litellm.litellm_core_utils.prompt_templates import factory as ptf
 from litellm.llms.base_llm.chat.transformation import BaseLLMException
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.watsonx import WatsonXAPIParams
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials


 class WatsonXAIError(BaseLLMException):

@@ -65,18 +67,20 @@ def generate_iam_token(api_key=None, **params) -> str:
 return cast(str, result)


+def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
+if token is not None:
+return token
+token = generate_iam_token(api_key)
+return token
+
+
 def _get_api_params(
 params: dict,
-print_verbose: Optional[Callable] = None,
-generate_token: Optional[bool] = True,
 ) -> WatsonXAPIParams:
 """
 Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
 """
 # Load auth variables from params
-url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
-api_key = params.pop("apikey", None)
-token = params.pop("token", None)
 project_id = params.pop(
 "project_id", params.pop("watsonx_project", None)
 ) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params

@@ -86,29 +90,8 @@
 region_name = params.pop(
 "watsonx_region_name", params.pop("watsonx_region", None)
 ) # consistent with how vertex ai + aws regions are accepted
-wx_credentials = params.pop(
-"wx_credentials",
-params.pop(
-"watsonx_credentials", None
-), # follow {provider}_credentials, same as vertex ai
-)
-api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
+
 # Load auth variables from environment variables
-if url is None:
-url = (
-get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
-or get_secret_str("WATSONX_URL")
-or get_secret_str("WX_URL")
-or get_secret_str("WML_URL")
-)
-if api_key is None:
-api_key = (
-get_secret_str("WATSONX_APIKEY")
-or get_secret_str("WATSONX_API_KEY")
-or get_secret_str("WX_API_KEY")
-)
-if token is None:
-token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
 if project_id is None:
 project_id = (
 get_secret_str("WATSONX_PROJECT_ID")

@@ -129,34 +112,6 @@
 or get_secret_str("SPACE_ID")
 )

-# credentials parsing
-if wx_credentials is not None:
-url = wx_credentials.get("url", url)
-api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
-token = wx_credentials.get(
-"token",
-wx_credentials.get(
-"watsonx_token", token
-), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
-)
-
-# verify that all required credentials are present
-if url is None:
-raise WatsonXAIError(
-status_code=401,
-message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
-)
-
-if token is None and api_key is not None and generate_token:
-# generate the auth token
-if print_verbose is not None:
-print_verbose("Generating IAM token for Watsonx.ai")
-token = generate_iam_token(api_key)
-elif token is None and api_key is None:
-raise WatsonXAIError(
-status_code=401,
-message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
-)
 if project_id is None:
 raise WatsonXAIError(
 status_code=401,

@@ -164,11 +119,147 @@
 )

 return WatsonXAPIParams(
-url=url,
-api_key=api_key,
-token=cast(str, token),
 project_id=project_id,
 space_id=space_id,
 region_name=region_name,
-api_version=api_version,
+)
+
+
+def convert_watsonx_messages_to_prompt(
+model: str,
+messages: List[AllMessageValues],
+provider: str,
+custom_prompt_dict: Dict,
+) -> str:
+# handle anthropic prompts and amazon titan prompts
+if model in custom_prompt_dict:
+# check if the model has a registered custom prompt
+model_prompt_dict = custom_prompt_dict[model]
+prompt = ptf.custom_prompt(
+messages=messages,
+role_dict=model_prompt_dict.get(
+"role_dict", model_prompt_dict.get("roles")
+),
+initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
+final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
+bos_token=model_prompt_dict.get("bos_token", ""),
+eos_token=model_prompt_dict.get("eos_token", ""),
+)
+return prompt
+elif provider == "ibm-mistralai":
+prompt = ptf.mistral_instruct_pt(messages=messages)
+else:
+prompt: str = ptf.prompt_factory( # type: ignore
+model=model, messages=messages, custom_llm_provider="watsonx"
+)
+return prompt
+
+
+# Mixin class for shared IBM Watson X functionality
+class IBMWatsonXMixin:
+def validate_environment(
+self,
+headers: Dict,
+model: str,
+messages: List[AllMessageValues],
+optional_params: Dict,
+api_key: Optional[str] = None,
+) -> Dict:
+headers = {
+"Content-Type": "application/json",
+"Accept": "application/json",
+}
+token = cast(Optional[str], optional_params.get("token"))
+if token:
+headers["Authorization"] = f"Bearer {token}"
+else:
+token = _generate_watsonx_token(api_key=api_key, token=token)
+# build auth headers
+headers["Authorization"] = f"Bearer {token}"
+return headers
+
+def _get_base_url(self, api_base: Optional[str]) -> str:
+url = (
+api_base
+or get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
+or get_secret_str("WATSONX_URL")
+or get_secret_str("WX_URL")
+or get_secret_str("WML_URL")
+)
+
+if url is None:
+raise WatsonXAIError(
+status_code=401,
+message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
+)
+return url
+
+def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
+api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
+url = url + f"?version={api_version}"
+
+return url
+
+def get_error_class(
+self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
+) -> BaseLLMException:
+return WatsonXAIError(
+status_code=status_code, message=error_message, headers=headers
+)
+
+@staticmethod
+def get_watsonx_credentials(
+optional_params: dict, api_key: Optional[str], api_base: Optional[str]
+) -> WatsonXCredentials:
+api_key = (
+api_key
+or optional_params.pop("apikey", None)
+or get_secret_str("WATSONX_APIKEY")
+or get_secret_str("WATSONX_API_KEY")
+or get_secret_str("WX_API_KEY")
+)
+
+api_base = (
+api_base
+or optional_params.pop(
+"url",
+optional_params.pop("api_base", optional_params.pop("base_url", None)),
+)
+or get_secret_str("WATSONX_API_BASE")
+or get_secret_str("WATSONX_URL")
+or get_secret_str("WX_URL")
+or get_secret_str("WML_URL")
+)
+
+wx_credentials = optional_params.pop(
+"wx_credentials",
+optional_params.pop(
+"watsonx_credentials", None
+), # follow {provider}_credentials, same as vertex ai
+)
+
+token: Optional[str] = None
+if wx_credentials is not None:
+api_base = wx_credentials.get("url", api_base)
+api_key = wx_credentials.get(
+"apikey", wx_credentials.get("api_key", api_key)
+)
+token = wx_credentials.get(
+"token",
+wx_credentials.get(
+"watsonx_token", None
+), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
+)
+if api_key is None or not isinstance(api_key, str):
+raise WatsonXAIError(
+status_code=401,
+message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
+)
+if api_base is None or not isinstance(api_base, str):
+raise WatsonXAIError(
+status_code=401,
+message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
+)
+return WatsonXCredentials(
+api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
 )
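For orientation, the `IBMWatsonXMixin.get_watsonx_credentials` helper added in the hunk above resolves credentials in this order: explicit arguments, then `apikey`/`url`-style entries (or a `wx_credentials` dict) in `optional_params`, then the `WATSONX_*`, `WX_*`, and `WML_*` environment variables, raising `WatsonXAIError` when no API key or base URL can be found. A hedged usage sketch (the import path is inferred from this diff's relative imports; the values are hypothetical):

```python
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin, WatsonXAIError

try:
    creds = IBMWatsonXMixin.get_watsonx_credentials(
        optional_params={"apikey": "fake-watsonx-key"},    # hypothetical key
        api_key=None,
        api_base="https://us-south.ml.cloud.ibm.com",      # hypothetical endpoint
    )
    # creds is a WatsonXCredentials carrying api_key, api_base and an optional token
    print(creds)
except WatsonXAIError as err:
    # raised when neither parameters nor WATSONX_* environment variables supply credentials
    print(f"could not resolve watsonx credentials: {err}")
```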
@@ -1,551 +1,3 @@
+"""
+Watsonx uses the llm_http_handler.py to handle the requests.
+"""
-import asyncio
-import json  # noqa: E401
-import time
-from contextlib import asynccontextmanager, contextmanager
-from datetime import datetime
-from typing import (
-    Any,
-    AsyncGenerator,
-    AsyncIterator,
-    Callable,
-    Generator,
-    Iterator,
-    List,
-    Optional,
-    Union,
-)
-
-import httpx  # type: ignore
-import requests  # type: ignore
-
-import litellm
-from litellm.litellm_core_utils.prompt_templates import factory as ptf
-from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
-from litellm.types.llms.openai import AllMessageValues
-from litellm.types.llms.watsonx import WatsonXAIEndpoint
-from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
-
-from ...base import BaseLLM
-from ..common_utils import WatsonXAIError, _get_api_params
-from .transformation import IBMWatsonXAIConfig
-
-
-def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
-    # handle anthropic prompts and amazon titan prompts
-    if model in custom_prompt_dict:
-        # check if the model has a registered custom prompt
-        model_prompt_dict = custom_prompt_dict[model]
-        prompt = ptf.custom_prompt(
-            messages=messages,
-            role_dict=model_prompt_dict.get(
-                "role_dict", model_prompt_dict.get("roles")
-            ),
-            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
-            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
-            bos_token=model_prompt_dict.get("bos_token", ""),
-            eos_token=model_prompt_dict.get("eos_token", ""),
-        )
-        return prompt
-    elif provider == "ibm-mistralai":
-        prompt = ptf.mistral_instruct_pt(messages=messages)
-    else:
-        prompt: str = ptf.prompt_factory(  # type: ignore
-            model=model, messages=messages, custom_llm_provider="watsonx"
-        )
-    return prompt
-
-
-class IBMWatsonXAI(BaseLLM):
-    """
-    Class to interface with IBM watsonx.ai API for text generation and embeddings.
-
-    Reference: https://cloud.ibm.com/apidocs/watsonx-ai
-    """
-
-    api_version = "2024-03-13"
-
-    def __init__(self) -> None:
-        super().__init__()
-
-    def _prepare_text_generation_req(
-        self,
-        model_id: str,
-        messages: List[AllMessageValues],
-        prompt: str,
-        stream: bool,
-        optional_params: dict,
-        print_verbose: Optional[Callable] = None,
-    ) -> dict:
-        """
-        Get the request parameters for text generation.
-        """
-        api_params = _get_api_params(optional_params, print_verbose=print_verbose)
-        # build auth headers
-        api_token = api_params.get("token")
-        self.token = api_token
-        headers = IBMWatsonXAIConfig().validate_environment(
-            headers={},
-            model=model_id,
-            messages=messages,
-            optional_params=optional_params,
-            api_key=api_token,
-        )
-        extra_body_params = optional_params.pop("extra_body", {})
-        optional_params.update(extra_body_params)
-        # init the payload to the text generation call
-        payload = {
-            "input": prompt,
-            "moderations": optional_params.pop("moderations", {}),
-            "parameters": optional_params,
-        }
-        request_params = dict(version=api_params["api_version"])
-        # text generation endpoint deployment or model / stream or not
-        if model_id.startswith("deployment/"):
-            # deployment models are passed in as 'deployment/<deployment_id>'
-            if api_params.get("space_id") is None:
-                raise WatsonXAIError(
-                    status_code=401,
-                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
-                )
-            deployment_id = "/".join(model_id.split("/")[1:])
-            endpoint = (
-                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
-                if stream
-                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
-            )
-            endpoint = endpoint.format(deployment_id=deployment_id)
-        else:
-            payload["model_id"] = model_id
-            payload["project_id"] = api_params["project_id"]
-            endpoint = (
-                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
-                if stream
-                else WatsonXAIEndpoint.TEXT_GENERATION
-            )
-        url = api_params["url"].rstrip("/") + endpoint
-        return dict(
-            method="POST", url=url, headers=headers, json=payload, params=request_params
-        )
-
-    def _process_text_gen_response(
-        self, json_resp: dict, model_response: Union[ModelResponse, None] = None
-    ) -> ModelResponse:
-        if "results" not in json_resp:
-            raise WatsonXAIError(
-                status_code=500,
-                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
-            )
-        if model_response is None:
-            model_response = ModelResponse(model=json_resp.get("model_id", None))
-        generated_text = json_resp["results"][0]["generated_text"]
-        prompt_tokens = json_resp["results"][0]["input_token_count"]
-        completion_tokens = json_resp["results"][0]["generated_token_count"]
-        model_response.choices[0].message.content = generated_text  # type: ignore
-        model_response.choices[0].finish_reason = map_finish_reason(
-            json_resp["results"][0]["stop_reason"]
-        )
-        if json_resp.get("created_at"):
-            model_response.created = int(
-                datetime.fromisoformat(json_resp["created_at"]).timestamp()
-            )
-        else:
-            model_response.created = int(time.time())
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens,
-        )
-        setattr(model_response, "usage", usage)
-        return model_response
-
-    def completion(
-        self,
-        model: str,
-        messages: list,
-        custom_prompt_dict: dict,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        encoding,
-        logging_obj: Any,
-        optional_params: dict,
-        acompletion=None,
-        litellm_params=None,
-        logger_fn=None,
-        timeout=None,
-    ):
-        """
-        Send a text generation request to the IBM Watsonx.ai API.
-
-        Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
-        """
-        stream = optional_params.pop("stream", False)
-
-        # Load default configs
-        config = IBMWatsonXAIConfig.get_config()
-        for k, v in config.items():
-            if k not in optional_params:
-                optional_params[k] = v
-
-        # Make prompt to send to model
-        provider = model.split("/")[0]
-        # model_name = "/".join(model.split("/")[1:])
-        prompt = convert_messages_to_prompt(
-            model, messages, provider, custom_prompt_dict
-        )
-        model_response.model = model
-
-        def process_stream_response(
-            stream_resp: Union[Iterator[str], AsyncIterator],
-        ) -> CustomStreamWrapper:
-            streamwrapper = litellm.CustomStreamWrapper(
-                stream_resp,
-                model=model,
-                custom_llm_provider="watsonx",
-                logging_obj=logging_obj,
-            )
-            return streamwrapper
-
-        # create the function to manage the request to watsonx.ai
-        self.request_manager = RequestManager(logging_obj)
-
-        def handle_text_request(request_params: dict) -> ModelResponse:
-            with self.request_manager.request(
-                request_params,
-                input=prompt,
-                timeout=timeout,
-            ) as resp:
-                json_resp = resp.json()
-
-            return self._process_text_gen_response(json_resp, model_response)
-
-        async def handle_text_request_async(request_params: dict) -> ModelResponse:
-            async with self.request_manager.async_request(
-                request_params,
-                input=prompt,
-                timeout=timeout,
-            ) as resp:
-                json_resp = resp.json()
-            return self._process_text_gen_response(json_resp, model_response)
-
-        def handle_stream_request(request_params: dict) -> CustomStreamWrapper:
-            # stream the response - generated chunks will be handled
-            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            with self.request_manager.request(
-                request_params,
-                stream=True,
-                input=prompt,
-                timeout=timeout,
-            ) as resp:
-                streamwrapper = process_stream_response(resp.iter_lines())
-            return streamwrapper
-
-        async def handle_stream_request_async(
-            request_params: dict,
-        ) -> CustomStreamWrapper:
-            # stream the response - generated chunks will be handled
-            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
-            async with self.request_manager.async_request(
-                request_params,
-                stream=True,
-                input=prompt,
-                timeout=timeout,
-            ) as resp:
-                streamwrapper = process_stream_response(resp.aiter_lines())
-            return streamwrapper
-
-        try:
-            ## Get the response from the model
-            req_params = self._prepare_text_generation_req(
-                model_id=model,
-                prompt=prompt,
-                messages=messages,
-                stream=stream,
-                optional_params=optional_params,
-                print_verbose=print_verbose,
-            )
-            if stream and (acompletion is True):
-                # stream and async text generation
-                return handle_stream_request_async(req_params)
-            elif stream:
-                # streaming text generation
-                return handle_stream_request(req_params)
-            elif acompletion is True:
-                # async text generation
-                return handle_text_request_async(req_params)
-            else:
-                # regular text generation
-                return handle_text_request(req_params)
-        except WatsonXAIError as e:
-            raise e
-        except Exception as e:
-            raise WatsonXAIError(status_code=500, message=str(e))
-
-    def _process_embedding_response(
-        self, json_resp: dict, model_response: Optional[EmbeddingResponse] = None
-    ) -> EmbeddingResponse:
-        if model_response is None:
-            model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
-        results = json_resp.get("results", [])
-        embedding_response = []
-        for idx, result in enumerate(results):
-            embedding_response.append(
-                {
-                    "object": "embedding",
-                    "index": idx,
-                    "embedding": result["embedding"],
-                }
-            )
-        model_response.object = "list"
-        model_response.data = embedding_response
-        input_tokens = json_resp.get("input_token_count", 0)
-        setattr(
-            model_response,
-            "usage",
-            Usage(
-                prompt_tokens=input_tokens,
-                completion_tokens=0,
-                total_tokens=input_tokens,
-            ),
-        )
-        return model_response
-
-    def embedding(
-        self,
-        model: str,
-        input: Union[list, str],
-        model_response: EmbeddingResponse,
-        api_key: Optional[str],
-        logging_obj: Any,
-        optional_params: dict,
-        encoding=None,
-        print_verbose=None,
-        aembedding=None,
-    ) -> EmbeddingResponse:
-        """
-        Send a text embedding request to the IBM Watsonx.ai API.
-        """
-        if optional_params is None:
-            optional_params = {}
-        # Load default configs
-        config = IBMWatsonXAIConfig.get_config()
-        for k, v in config.items():
-            if k not in optional_params:
-                optional_params[k] = v
-
-        model_response.model = model
-
-        # Load auth variables from environment variables
-        if isinstance(input, str):
-            input = [input]
-        if api_key is not None:
-            optional_params["api_key"] = api_key
-        api_params = _get_api_params(optional_params)
-        # build auth headers
-        api_token = api_params.get("token")
-        self.token = api_token
-        headers = {
-            "Authorization": f"Bearer {api_token}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-        # init the payload to the text generation call
-        payload = {
-            "inputs": input,
-            "model_id": model,
-            "project_id": api_params["project_id"],
-            "parameters": optional_params,
-        }
-        request_params = dict(version=api_params["api_version"])
-        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
-        req_params = {
-            "method": "POST",
-            "url": url,
-            "headers": headers,
-            "json": payload,
-            "params": request_params,
-        }
-        request_manager = RequestManager(logging_obj)
-
-        def handle_embedding(request_params: dict) -> EmbeddingResponse:
-            with request_manager.request(request_params, input=input) as resp:
-                json_resp = resp.json()
-            return self._process_embedding_response(json_resp, model_response)
-
-        async def handle_aembedding(request_params: dict) -> EmbeddingResponse:
-            async with request_manager.async_request(
-                request_params, input=input
-            ) as resp:
-                json_resp = resp.json()
-            return self._process_embedding_response(json_resp, model_response)
-
-        try:
-            if aembedding is True:
-                return handle_aembedding(req_params)  # type: ignore
-            else:
-                return handle_embedding(req_params)
-        except WatsonXAIError as e:
-            raise e
-        except Exception as e:
-            raise WatsonXAIError(status_code=500, message=str(e))
-
-    def get_available_models(self, *, ids_only: bool = True, **params):
-        api_params = _get_api_params(params)
-        self.token = api_params["token"]
-        headers = {
-            "Authorization": f"Bearer {api_params['token']}",
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-        request_params = dict(version=api_params["api_version"])
-        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
-        req_params = dict(method="GET", url=url, headers=headers, params=request_params)
-        with RequestManager(logging_obj=None).request(req_params) as resp:
-            json_resp = resp.json()
-        if not ids_only:
-            return json_resp
-        return [res["model_id"] for res in json_resp["resources"]]
-
-
-class RequestManager:
-    """
-    A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
-
-    Usage:
-    ```python
-    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
-    request_manager = RequestManager(logging_obj=logging_obj)
-    with request_manager.request(request_params) as resp:
-        ...
-    # or
-    async with request_manager.async_request(request_params) as resp:
-        ...
-    ```
-    """
-
-    def __init__(self, logging_obj=None):
-        self.logging_obj = logging_obj
-
-    def pre_call(
-        self,
-        request_params: dict,
-        input: Optional[Any] = None,
-        is_async: Optional[bool] = False,
-    ):
-        if self.logging_obj is None:
-            return
-        request_str = (
-            f"response = {'await ' if is_async else ''}{request_params['method']}(\n"
-            f"\turl={request_params['url']},\n"
-            f"\tjson={request_params.get('json')},\n"
-            f")"
-        )
-        self.logging_obj.pre_call(
-            input=input,
-            api_key=request_params["headers"].get("Authorization"),
-            additional_args={
-                "complete_input_dict": request_params.get("json"),
-                "request_str": request_str,
-            },
-        )
-
-    def post_call(self, resp, request_params):
-        if self.logging_obj is None:
-            return
-        self.logging_obj.post_call(
-            input=input,
-            api_key=request_params["headers"].get("Authorization"),
-            original_response=json.dumps(resp.json()),
-            additional_args={
-                "status_code": resp.status_code,
-                "complete_input_dict": request_params.get(
-                    "data", request_params.get("json")
-                ),
-            },
-        )
-
-    @contextmanager
-    def request(
-        self,
-        request_params: dict,
-        stream: bool = False,
-        input: Optional[Any] = None,
-        timeout=None,
-    ) -> Generator[requests.Response, None, None]:
-        """
-        Returns a context manager that yields the response from the request.
-        """
-        self.pre_call(request_params, input)
-        if timeout:
-            request_params["timeout"] = timeout
-        if stream:
-            request_params["stream"] = stream
-        try:
-            resp = requests.request(**request_params)
-            if not resp.ok:
-                raise WatsonXAIError(
-                    status_code=resp.status_code,
-                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
-                )
-            yield resp
-        except Exception as e:
-            raise WatsonXAIError(status_code=500, message=str(e))
-        if not stream:
-            self.post_call(resp, request_params)
-
-    @asynccontextmanager
-    async def async_request(
-        self,
-        request_params: dict,
-        stream: bool = False,
-        input: Optional[Any] = None,
-        timeout=None,
-    ) -> AsyncGenerator[httpx.Response, None]:
-        self.pre_call(request_params, input, is_async=True)
-        if timeout:
-            request_params["timeout"] = timeout
-        if stream:
-            request_params["stream"] = stream
-        try:
-            self.async_handler = get_async_httpx_client(
-                llm_provider=litellm.LlmProviders.WATSONX,
-                params={
-                    "timeout": httpx.Timeout(
-                        timeout=request_params.pop("timeout", 600.0), connect=5.0
-                    ),
-                },
-            )
-            if "json" in request_params:
-                request_params["data"] = json.dumps(request_params.pop("json", {}))
-            method = request_params.pop("method")
-            retries = 0
-            resp: Optional[httpx.Response] = None
-            while retries < 3:
-                if method.upper() == "POST":
-                    resp = await self.async_handler.post(**request_params)
-                else:
-                    resp = await self.async_handler.get(**request_params)
-                if resp is not None and resp.status_code in [429, 503, 504, 520]:
-                    # to handle rate limiting and service unavailable errors
-                    # see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
-                    await asyncio.sleep(2**retries)
-                    retries += 1
-                else:
-                    break
-            if resp is None:
-                raise WatsonXAIError(
-                    status_code=500,
-                    message="No response from the server",
-                )
-            if resp.is_error:
-                error_reason = getattr(resp, "reason", "")
-                raise WatsonXAIError(
-                    status_code=resp.status_code,
-                    message=f"Error {resp.status_code} ({error_reason}): {resp.text}",
-                )
-            yield resp
-            # await async_handler.close()
-        except Exception as e:
-            raise e
-            raise WatsonXAIError(status_code=500, message=str(e))
-        if not stream:
-            self.post_call(resp, request_params)
@@ -1,13 +1,31 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+import time
+from datetime import datetime
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Union,
+)
 
 import httpx
 
-from litellm.llms.base_llm.chat.transformation import BaseLLMException
-from litellm.types.llms.openai import AllMessageValues
-from litellm.utils import ModelResponse
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
+from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage
+from litellm.utils import map_finish_reason
 
 from ...base_llm.chat.transformation import BaseConfig
-from ..common_utils import WatsonXAIError
+from ..common_utils import (
+    IBMWatsonXMixin,
+    WatsonXAIError,
+    _get_api_params,
+    convert_watsonx_messages_to_prompt,
+)
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -17,7 +35,7 @@ else:
     LiteLLMLoggingObj = Any
 
 
-class IBMWatsonXAIConfig(BaseConfig):
+class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
     """
     Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
     (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
@@ -210,13 +228,6 @@ class IBMWatsonXAIConfig(BaseConfig):
             "us-south",
         ]
 
-    def get_error_class(
-        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
-    ) -> BaseLLMException:
-        return WatsonXAIError(
-            status_code=status_code, message=error_message, headers=headers
-        )
-
     def transform_request(
         self,
         model: str,
@@ -225,9 +236,28 @@ class IBMWatsonXAIConfig(BaseConfig):
         litellm_params: Dict,
         headers: Dict,
     ) -> Dict:
-        raise NotImplementedError(
-            "transform_request not implemented. Done in watsonx/completion handler.py"
-        )
+        provider = model.split("/")[0]
+        prompt = convert_watsonx_messages_to_prompt(
+            model=model,
+            messages=messages,
+            provider=provider,
+            custom_prompt_dict={},
+        )
+        extra_body_params = optional_params.pop("extra_body", {})
+        optional_params.update(extra_body_params)
+        watsonx_api_params = _get_api_params(params=optional_params)
+        # init the payload to the text generation call
+        payload = {
+            "input": prompt,
+            "moderations": optional_params.pop("moderations", {}),
+            "parameters": optional_params,
+        }
+
+        if not model.startswith("deployment/"):
+            payload["model_id"] = model
+            payload["project_id"] = watsonx_api_params["project_id"]
+
+        return payload
 
     def transform_response(
         self,
@@ -243,22 +273,120 @@ class IBMWatsonXAIConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError(
-            "transform_response not implemented. Done in watsonx/completion handler.py"
-        )
-
-    def validate_environment(
-        self,
-        headers: Dict,
-        model: str,
-        messages: List[AllMessageValues],
-        optional_params: Dict,
-        api_key: Optional[str] = None,
-    ) -> Dict:
-        headers = {
-            "Content-Type": "application/json",
-            "Accept": "application/json",
-        }
-        if api_key:
-            headers["Authorization"] = f"Bearer {api_key}"
-        return headers
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response=raw_response.text,
+        )
+
+        json_resp = raw_response.json()
+
+        if "results" not in json_resp:
+            raise WatsonXAIError(
+                status_code=500,
+                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
+            )
+        if model_response is None:
+            model_response = ModelResponse(model=json_resp.get("model_id", None))
+        generated_text = json_resp["results"][0]["generated_text"]
+        prompt_tokens = json_resp["results"][0]["input_token_count"]
+        completion_tokens = json_resp["results"][0]["generated_token_count"]
+        model_response.choices[0].message.content = generated_text  # type: ignore
+        model_response.choices[0].finish_reason = map_finish_reason(
+            json_resp["results"][0]["stop_reason"]
+        )
+        if json_resp.get("created_at"):
+            model_response.created = int(
+                datetime.fromisoformat(json_resp["created_at"]).timestamp()
+            )
+        else:
+            model_response.created = int(time.time())
+        usage = Usage(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def get_complete_url(
+        self,
+        api_base: str,
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        if model.startswith("deployment/"):
+            # deployment models are passed in as 'deployment/<deployment_id>'
+            if optional_params.get("space_id") is None:
+                raise WatsonXAIError(
+                    status_code=401,
+                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+                )
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = (
+                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
+                if stream
+                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            endpoint = (
+                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
+                if stream
+                else WatsonXAIEndpoint.TEXT_GENERATION
+            )
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return WatsonxTextCompletionResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class WatsonxTextCompletionResponseIterator(BaseModelResponseIterator):
+    # def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
+    #     return self.chunk_parser(json.loads(str_line))
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            results = chunk.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
+                is_finished = finish_reason != "not_finished"
+
+                return GenericStreamingChunk(
+                    text=text,
+                    is_finished=is_finished,
+                    finish_reason=finish_reason,
+                    usage=ChatCompletionUsageBlock(
+                        prompt_tokens=results[0].get("input_token_count", 0),
+                        completion_tokens=results[0].get("generated_token_count", 0),
+                        total_tokens=results[0].get("input_token_count", 0)
+                        + results[0].get("generated_token_count", 0),
+                    ),
+                )
+            return GenericStreamingChunk(
+                text="",
+                is_finished=False,
+                finish_reason="stop",
+                usage=None,
+            )
+        except Exception as e:
+            raise e
116 litellm/llms/watsonx/embed/transformation.py (Normal file)
@@ -0,0 +1,116 @@
+"""
+Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
+"""
+
+from typing import Optional
+
+import httpx
+
+from litellm.llms.base_llm.embedding.transformation import (
+    BaseEmbeddingConfig,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllEmbeddingInputValues
+from litellm.types.llms.watsonx import WatsonXAIEndpoint
+from litellm.types.utils import EmbeddingResponse, Usage
+
+from ..common_utils import IBMWatsonXMixin, WatsonXAIError, _get_api_params
+
+
+class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
+    def get_supported_openai_params(self, model: str) -> list:
+        return []
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        return optional_params
+
+    def transform_embedding_request(
+        self,
+        model: str,
+        input: AllEmbeddingInputValues,
+        optional_params: dict,
+        headers: dict,
+    ) -> dict:
+        watsonx_api_params = _get_api_params(params=optional_params)
+        project_id = watsonx_api_params["project_id"]
+        if not project_id:
+            raise ValueError("project_id is required")
+        return {
+            "inputs": input,
+            "model_id": model,
+            "project_id": project_id,
+            "parameters": optional_params,
+        }
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        url = self._get_base_url(api_base=api_base)
+        endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
+        if model.startswith("deployment/"):
+            # deployment models are passed in as 'deployment/<deployment_id>'
+            if optional_params.get("space_id") is None:
+                raise WatsonXAIError(
+                    status_code=401,
+                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+                )
+            deployment_id = "/".join(model.split("/")[1:])
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        url = url.rstrip("/") + endpoint
+
+        ## add api version
+        url = self._add_api_version_to_url(
+            url=url, api_version=optional_params.pop("api_version", None)
+        )
+        return url
+
+    def transform_embedding_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: EmbeddingResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str],
+        request_data: dict,
+        optional_params: dict,
+        litellm_params: dict,
+    ) -> EmbeddingResponse:
+        logging_obj.post_call(
+            original_response=raw_response.text,
+        )
+        json_resp = raw_response.json()
+        if model_response is None:
+            model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {
+                    "object": "embedding",
+                    "index": idx,
+                    "embedding": result["embedding"],
+                }
+            )
+        model_response.object = "list"
+        model_response.data = embedding_response
+        input_tokens = json_resp.get("input_token_count", 0)
+        setattr(
+            model_response,
+            "usage",
+            Usage(
+                prompt_tokens=input_tokens,
+                completion_tokens=0,
+                total_tokens=input_tokens,
+            ),
+        )
+        return model_response
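For reference, a minimal usage sketch (not part of the commit) of how this new embedding config is exercised through litellm's public API. It assumes watsonx credentials are available via the environment variables this change reads (e.g. `WATSONX_URL`, `WATSONX_APIKEY`, and a project id such as `WATSONX_PROJECT_ID`); the model id below is only an illustrative placeholder.

```python
# Hedged usage sketch -- assumes watsonx credentials are set in the environment;
# "ibm/slate-30m-english-rtrvr" is an illustrative model id, not taken from the diff.
import litellm

response = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",
    input=["Hello from litellm"],
)
# the transformation above populates .data and .usage from watsonx's /text/embeddings response
print(len(response.data), response.usage)
```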
134 litellm/main.py
@@ -145,7 +145,7 @@ from .llms.vertex_ai.vertex_embeddings.embedding_handler import VertexEmbedding
 from .llms.vertex_ai.vertex_model_garden.main import VertexAIModelGardenModels
 from .llms.vllm.completion import handler as vllm_handler
 from .llms.watsonx.chat.handler import WatsonXChatHandler
-from .llms.watsonx.completion.handler import IBMWatsonXAI
+from .llms.watsonx.common_utils import IBMWatsonXMixin
 from .types.llms.openai import (
     ChatCompletionAssistantMessage,
     ChatCompletionAudioParam,
@@ -205,7 +205,6 @@ google_batch_embeddings = GoogleBatchEmbeddings()
 vertex_partner_models_chat_completion = VertexAIPartnerModels()
 vertex_model_garden_chat_completion = VertexAIModelGardenModels()
 vertex_text_to_speech = VertexTextToSpeechAPI()
-watsonxai = IBMWatsonXAI()
 sagemaker_llm = SagemakerLLM()
 watsonx_chat_completion = WatsonXChatHandler()
 openai_like_embedding = OpenAILikeEmbeddingHandler()
@@ -2585,43 +2584,68 @@ def completion(  # type: ignore # noqa: PLR0915
             custom_llm_provider="watsonx",
         )
     elif custom_llm_provider == "watsonx_text":
-        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        response = watsonxai.completion(
-            model=model,
-            messages=messages,
-            custom_prompt_dict=custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            optional_params=optional_params,
-            litellm_params=litellm_params,  # type: ignore
-            logger_fn=logger_fn,
-            encoding=encoding,
-            logging_obj=logging,
-            timeout=timeout,  # type: ignore
-            acompletion=acompletion,
-        )
-        if (
-            "stream" in optional_params
-            and optional_params["stream"] is True
-            and not isinstance(response, CustomStreamWrapper)
-        ):
-            # don't try to access stream object,
-            response = CustomStreamWrapper(
-                iter(response),
-                model,
-                custom_llm_provider="watsonx",
-                logging_obj=logging,
-            )
-
-        if optional_params.get("stream", False):
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=None,
-                original_response=response,
-            )
-        ## RESPONSE OBJECT
-        response = response
+        api_key = (
+            api_key
+            or optional_params.pop("apikey", None)
+            or get_secret_str("WATSONX_APIKEY")
+            or get_secret_str("WATSONX_API_KEY")
+            or get_secret_str("WX_API_KEY")
+        )
+
+        api_base = (
+            api_base
+            or optional_params.pop(
+                "url",
+                optional_params.pop(
+                    "api_base", optional_params.pop("base_url", None)
+                ),
+            )
+            or get_secret_str("WATSONX_API_BASE")
+            or get_secret_str("WATSONX_URL")
+            or get_secret_str("WX_URL")
+            or get_secret_str("WML_URL")
+        )
+
+        wx_credentials = optional_params.pop(
+            "wx_credentials",
+            optional_params.pop(
+                "watsonx_credentials", None
+            ),  # follow {provider}_credentials, same as vertex ai
+        )
+
+        token: Optional[str] = None
+        if wx_credentials is not None:
+            api_base = wx_credentials.get("url", api_base)
+            api_key = wx_credentials.get(
+                "apikey", wx_credentials.get("api_key", api_key)
+            )
+            token = wx_credentials.get(
+                "token",
+                wx_credentials.get(
+                    "watsonx_token", None
+                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
+            )
+
+        if token is not None:
+            optional_params["token"] = token
+
+        response = base_llm_http_handler.completion(
+            model=model,
+            stream=stream,
+            messages=messages,
+            acompletion=acompletion,
+            api_base=api_base,
+            model_response=model_response,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            custom_llm_provider="watsonx_text",
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+            client=client,
+        )
     elif custom_llm_provider == "vllm":
         custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
         model_response = vllm_handler.completion(
@@ -3485,6 +3509,7 @@ def embedding(  # noqa: PLR0915
             optional_params=optional_params,
             client=client,
             aembedding=aembedding,
+            litellm_params={},
         )
     elif custom_llm_provider == "gemini":
         gemini_api_key = (
@@ -3661,6 +3686,32 @@ def embedding(  # noqa: PLR0915
             optional_params=optional_params,
             client=client,
             aembedding=aembedding,
+            litellm_params={},
+        )
+    elif custom_llm_provider == "watsonx":
+        credentials = IBMWatsonXMixin.get_watsonx_credentials(
+            optional_params=optional_params, api_key=api_key, api_base=api_base
+        )
+
+        api_key = credentials["api_key"]
+        api_base = credentials["api_base"]
+
+        if "token" in credentials:
+            optional_params["token"] = credentials["token"]
+
+        response = base_llm_http_handler.embedding(
+            model=model,
+            input=input,
+            custom_llm_provider=custom_llm_provider,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            litellm_params={},
+            client=client,
+            aembedding=aembedding,
        )
     elif custom_llm_provider == "xinference":
         api_key = (
@@ -3687,17 +3738,6 @@ def embedding(  # noqa: PLR0915
             client=client,
             aembedding=aembedding,
         )
-    elif custom_llm_provider == "watsonx":
-        response = watsonxai.embedding(
-            model=model,
-            input=input,
-            encoding=encoding,
-            logging_obj=logging,
-            optional_params=optional_params,
-            model_response=EmbeddingResponse(),
-            aembedding=aembedding,
-            api_key=api_key,
-        )
     elif custom_llm_provider == "azure_ai":
         api_base = (
             api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
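A minimal usage sketch (not part of the diff) of the `watsonx_text` route now handled by `base_llm_http_handler` above. It assumes credentials are supplied through the environment variables resolved in that branch (`WATSONX_APIKEY`/`WATSONX_API_KEY`, `WATSONX_URL`, etc.); the model id comes from the test updated later in this commit.

```python
# Hedged usage sketch -- credentials are read from the environment variables
# resolved in the watsonx_text branch above.
import litellm

# non-streaming text generation via watsonx's /text/generation endpoint
resp = litellm.completion(
    model="watsonx_text/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Say hello"}],
    max_tokens=20,
)
print(resp.choices[0].message.content)

# streaming yields chunks parsed by WatsonxTextCompletionResponseIterator
for chunk in litellm.completion(
    model="watsonx_text/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
):
    print(chunk.choices[0].delta.content or "", end="")
```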
@@ -611,3 +611,6 @@ class FineTuningJobCreate(BaseModel):
 class LiteLLMFineTuningJobCreate(FineTuningJobCreate):
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"]
+
+
+AllEmbeddingInputValues = Union[str, List[str], List[int], List[List[int]]]
@@ -6,13 +6,15 @@ from pydantic import BaseModel
 class WatsonXAPIParams(TypedDict):
-    url: str
-    api_key: Optional[str]
-    token: str
     project_id: str
     space_id: Optional[str]
     region_name: Optional[str]
-    api_version: str
+
+
+class WatsonXCredentials(TypedDict):
+    api_key: str
+    api_base: str
+    token: Optional[str]
 
 
 class WatsonXAIEndpoint(str, Enum):
@@ -808,6 +808,8 @@ class ModelResponseStream(ModelResponseBase):
     def __init__(
         self,
         choices: Optional[List[Union[StreamingChoices, dict, BaseModel]]] = None,
+        id: Optional[str] = None,
+        created: Optional[int] = None,
         **kwargs,
     ):
         if choices is not None and isinstance(choices, list):
@@ -824,6 +826,20 @@ class ModelResponseStream(ModelResponseBase):
             kwargs["choices"] = new_choices
         else:
             kwargs["choices"] = [StreamingChoices()]
+
+        if id is None:
+            id = _generate_id()
+        else:
+            id = id
+        if created is None:
+            created = int(time.time())
+        else:
+            created = created
+
+        kwargs["id"] = id
+        kwargs["created"] = created
+        kwargs["object"] = "chat.completion.chunk"
+
         super().__init__(**kwargs)
 
     def __contains__(self, key):
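A small sketch (not part of the diff) of what the `ModelResponseStream` change above enables: constructing a chunk without explicit `id`/`created` and relying on the new defaults. The constructor call below is illustrative only.

```python
# Hedged sketch -- illustrates the new ModelResponseStream defaults added above.
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices

chunk = ModelResponseStream(
    choices=[StreamingChoices(index=0, delta=Delta(content="Hello"))],
)
# id and created are now auto-populated, and object is "chat.completion.chunk"
print(chunk.id, chunk.created, chunk.object)
```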
@@ -6244,7 +6244,9 @@ class ProviderConfigManager:
             return litellm.VoyageEmbeddingConfig()
         elif litellm.LlmProviders.TRITON == provider:
             return litellm.TritonEmbeddingConfig()
-        raise ValueError(f"Provider {provider} does not support embedding config")
+        elif litellm.LlmProviders.WATSONX == provider:
+            return litellm.IBMWatsonXEmbeddingConfig()
+        raise ValueError(f"Provider {provider.value} does not support embedding config")
 
     @staticmethod
     def get_provider_rerank_config(
2542 poetry.lock (generated)
File diff suppressed because it is too large. Load diff
@@ -26,7 +26,6 @@ tokenizers = "*"
 click = "*"
 jinja2 = "^3.1.2"
 aiohttp = "*"
-requests = "^2.31.0"
 pydantic = "^2.0.0"
 jsonschema = "^4.22.0"
@@ -133,7 +133,7 @@ def test_completion_xai(stream):
         for chunk in response:
             print(chunk)
             assert chunk is not None
-            assert isinstance(chunk, litellm.ModelResponse)
+            assert isinstance(chunk, litellm.ModelResponseStream)
             assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
 
     else:
@@ -1,104 +0,0 @@
@@ -1805,7 +1805,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
 
         for chunk in response:
             chunks.append(chunk)
-            assert isinstance(chunk, litellm.ModelResponse)
+            assert isinstance(chunk, litellm.ModelResponseStream)
     else:
         response = await litellm.acompletion(**data)
         print(f"completion: {response}")
@@ -1815,7 +1815,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
         async for chunk in response:
             print(f"chunk: {chunk}")
             chunks.append(chunk)
-            assert isinstance(chunk, litellm.ModelResponse)
+            assert isinstance(chunk, litellm.ModelResponseStream)
 
         complete_response = litellm.stream_chunk_builder(chunks=chunks)
         assert (
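For orientation, a minimal sketch (not part of the diff) of the contract these assertions now encode: each streamed chunk is a `litellm.ModelResponseStream`, while the object rebuilt by `stream_chunk_builder` is still a full `litellm.ModelResponse`. The model name below is only an example.

```python
import litellm

def check_stream_types(model: str = "gpt-3.5-turbo"):
    # Stream a completion and verify the per-chunk vs. assembled response types.
    chunks = []
    response = litellm.completion(
        model=model,
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    for chunk in response:
        assert isinstance(chunk, litellm.ModelResponseStream)
        chunks.append(chunk)
    complete_response = litellm.stream_chunk_builder(chunks=chunks)
    assert isinstance(complete_response, litellm.ModelResponse)
    return complete_response
```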
@@ -4019,20 +4019,21 @@ def test_completion_deepseek():
 @pytest.mark.skip(reason="Account deleted by IBM.")
 def test_completion_watsonx_error():
     litellm.set_verbose = True
-    model_name = "watsonx/ibm/granite-13b-chat-v2"
+    model_name = "watsonx_text/ibm/granite-13b-chat-v2"
 
-    with pytest.raises(litellm.BadRequestError) as e:
     response = completion(
         model=model_name,
         messages=messages,
         stop=["stop"],
         max_tokens=20,
+        stream=True,
     )
+
+    for chunk in response:
+        print(chunk)
     # Add any assertions here to check the response
     print(response)
 
-    assert "use 'watsonx_text' route instead" in str(e).lower()
-
 
 @pytest.mark.skip(reason="Skip test. account deleted.")
 def test_completion_stream_watsonx():
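As an aside (not the commit's code), the reworked test drives the text-generation route directly; a hedged sketch, assuming watsonx credentials are already configured in the environment:

```python
import litellm

# Stream from the explicit watsonx_text/ route used in the test above.
# Credentials and project settings are assumed to come from the environment.
response = litellm.completion(
    model="watsonx_text/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stop=["stop"],
    max_tokens=20,
    stream=True,
)
for chunk in response:
    print(chunk)
```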
@@ -135,7 +135,7 @@ class CompletionCustomHandler(
             ## END TIME
             assert isinstance(end_time, datetime)
             ## RESPONSE OBJECT
-            assert isinstance(response_obj, litellm.ModelResponse)
+            assert isinstance(response_obj, litellm.ModelResponseStream)
             ## KWARGS
             assert isinstance(kwargs["model"], str)
             assert isinstance(kwargs["messages"], list) and isinstance(
@@ -153,7 +153,7 @@ class CompletionCustomHandler(
             ## END TIME
             assert isinstance(end_time, datetime)
             ## RESPONSE OBJECT
-            assert isinstance(response_obj, litellm.ModelResponse)
+            assert isinstance(response_obj, litellm.ModelResponseStream)
             ## KWARGS
             assert isinstance(kwargs["model"], str)
             assert isinstance(kwargs["messages"], list) and isinstance(
@@ -122,8 +122,13 @@ async def test_async_chat_openai_stream():
 
         complete_streaming_response = complete_streaming_response.strip("'")
 
+        print(f"complete_streaming_response: {complete_streaming_response}")
+
         await asyncio.sleep(3)
 
+        print(
+            f"tmp_function.complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback}"
+        )
         # problematic line
         response1 = tmp_function.complete_streaming_response_in_callback["choices"][0][
             "message"
@@ -801,8 +801,11 @@ def test_fireworks_embeddings():
 
 
 def test_watsonx_embeddings():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
 
-    def mock_wx_embed_request(method: str, url: str, **kwargs):
+    client = HTTPHandler()
+
+    def mock_wx_embed_request(url: str, **kwargs):
         mock_response = MagicMock()
         mock_response.status_code = 200
         mock_response.headers = {"Content-Type": "application/json"}
@@ -816,12 +819,14 @@ def test_watsonx_embeddings():
 
     try:
         litellm.set_verbose = True
-        with patch("requests.request", side_effect=mock_wx_embed_request):
+        with patch.object(client, "post", side_effect=mock_wx_embed_request):
             response = litellm.embedding(
                 model="watsonx/ibm/slate-30m-english-rtrvr",
                 input=["good morning from litellm"],
                 token="secret-token",
+                client=client,
             )
+
             print(f"response: {response}")
             assert isinstance(response.usage, litellm.Usage)
     except litellm.RateLimitError as e:
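The change above swaps `patch("requests.request", ...)` for a patched `HTTPHandler` that is handed to the call via `client=`. A self-contained sketch of just the patching mechanics (the empty JSON payload is a placeholder, not the real watsonx response):

```python
from unittest.mock import MagicMock, patch

from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()

def mock_post(url: str, **kwargs):
    # Build a fake httpx-style response; a real test would return the
    # provider's actual JSON payload instead of the empty dict below.
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.headers = {"Content-Type": "application/json"}
    mock_response.json.return_value = {}  # placeholder payload
    return mock_response

with patch.object(client, "post", side_effect=mock_post) as mock_client:
    client.post("https://example.invalid/v1/embeddings", json={"input": ["hi"]})
    mock_client.assert_called_once()
```

Passing the same `client` into `litellm.embedding(..., client=client)` is what lets the patched handler intercept the outbound request.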
@@ -832,6 +837,9 @@ def test_watsonx_embeddings():
 
 @pytest.mark.asyncio
 async def test_watsonx_aembeddings():
+    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+    client = AsyncHTTPHandler()
 
     def mock_async_client(*args, **kwargs):
 
@@ -856,12 +864,14 @@ async def test_watsonx_aembeddings():
 
     try:
         litellm.set_verbose = True
-        with patch("httpx.AsyncClient", side_effect=mock_async_client):
+        with patch.object(client, "post", side_effect=mock_async_client) as mock_client:
             response = await litellm.aembedding(
                 model="watsonx/ibm/slate-30m-english-rtrvr",
                 input=["good morning from litellm"],
                 token="secret-token",
+                client=client,
             )
+            mock_client.assert_called_once()
             print(f"response: {response}")
             assert isinstance(response.usage, litellm.Usage)
     except litellm.RateLimitError as e:
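The async test mirrors this with `AsyncHTTPHandler`; one detail worth noting (again a sketch, not the commit's code) is that `patch.object` installs an `AsyncMock` for the async `post` method, so a plain function still works as the `side_effect` and the awaited call returns its result:

```python
import asyncio
from unittest.mock import MagicMock, patch

from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

async def main() -> None:
    client = AsyncHTTPHandler()

    def mock_post(*args, **kwargs):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}
        mock_response.json.return_value = {}  # placeholder payload
        return mock_response

    # patch.object sees the async method and substitutes an AsyncMock,
    # so awaiting client.post(...) yields mock_post's return value.
    with patch.object(client, "post", side_effect=mock_post) as mock_client:
        await client.post("https://example.invalid/v1/embeddings", json={"input": ["hi"]})
        mock_client.assert_called_once()

asyncio.run(main())
```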
@@ -17,6 +17,7 @@ from pydantic import BaseModel
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
 from litellm.utils import ModelResponseListIterator
+from litellm.types.utils import ModelResponseStream
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -69,7 +70,7 @@ first_openai_chunk_example = {
 
 def validate_first_format(chunk):
     # write a test to make sure chunk follows the same format as first_openai_chunk_example
-    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
+    assert isinstance(chunk, ModelResponseStream), "Chunk should be a dictionary."
     assert isinstance(chunk["id"], str), "'id' should be a string."
     assert isinstance(chunk["object"], str), "'object' should be a string."
     assert isinstance(chunk["created"], int), "'created' should be an integer."
@@ -99,7 +100,7 @@ second_openai_chunk_example = {
 
 
 def validate_second_format(chunk):
-    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
+    assert isinstance(chunk, ModelResponseStream), "Chunk should be a dictionary."
     assert isinstance(chunk["id"], str), "'id' should be a string."
     assert isinstance(chunk["object"], str), "'object' should be a string."
     assert isinstance(chunk["created"], int), "'created' should be an integer."
@@ -137,7 +138,7 @@ def validate_last_format(chunk):
     """
     Ensure last chunk has no remaining content or tools
     """
-    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
+    assert isinstance(chunk, ModelResponseStream), "Chunk should be a dictionary."
     assert isinstance(chunk["id"], str), "'id' should be a string."
     assert isinstance(chunk["object"], str), "'object' should be a string."
     assert isinstance(chunk["created"], int), "'created' should be an integer."
@@ -1523,7 +1524,7 @@ async def test_parallel_streaming_requests(sync_mode, model):
             num_finish_reason = 0
             for chunk in response:
                 print(f"chunk: {chunk}")
-                if isinstance(chunk, ModelResponse):
+                if isinstance(chunk, ModelResponseStream):
                     if chunk.choices[0].finish_reason is not None:
                         num_finish_reason += 1
             assert num_finish_reason == 1
@@ -1541,7 +1542,7 @@ async def test_parallel_streaming_requests(sync_mode, model):
             num_finish_reason = 0
             async for chunk in response:
                 print(f"type of chunk: {type(chunk)}")
-                if isinstance(chunk, ModelResponse):
+                if isinstance(chunk, ModelResponseStream):
                     print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
                     if chunk.choices[0].finish_reason is not None:
                         num_finish_reason += 1
@@ -1,104 +0,0 @@
============================= test session starts ==============================
platform darwin -- Python 3.11.4, pytest-8.3.2, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/myenv/bin/python3.11
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm
configfile: pyproject.toml
plugins: asyncio-0.23.8, respx-0.21.1, anyio-4.6.0
asyncio: mode=Mode.STRICT
collecting ... collected 1 item

test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307] <module 'litellm' from '/Users/krrishdholakia/Documents/litellm/litellm/__init__.py'>


[92mRequest to litellm:[0m
[92mlitellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}], tools=[{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], tool_choice='auto')[0m


SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
Final returned optional params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
SENT optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096}
tool: {'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}
[92m

POST Request Sent from LiteLLM:
curl -X POST \
https://api.anthropic.com/v1/messages \
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}], 'tools': [{'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'input_schema': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'
[0m

_is_function_call: False
RAW RESPONSE:
{"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}


raw model_response: {"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}
Logging Details LiteLLM-Success Call: Cache_hit=None
Looking up model=claude-3-haiku-20240307 in model_cost_map
Looking up model=claude-3-haiku-20240307 in model_cost_map
Response
ModelResponse(id='chatcmpl-7222f6c2-962a-4776-8639-576723466cb7', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None))], created=1727897483, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=87, prompt_tokens=379, total_tokens=466, completion_tokens_details=None))
length of tool calls 1
Expecting there to be 3 tool calls
tool_calls: [ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')]
Response message
Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None)
messages: [{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]


[92mRequest to litellm:[0m
[92mlitellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}], temperature=0.2, seed=22, drop_params=True)[0m


SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
Final returned optional params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
SENT optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}], 'max_tokens': 4096}
tool: {'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}
[92m

POST Request Sent from LiteLLM:
curl -X POST \
https://api.anthropic.com/v1/messages \
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}, {'role': 'assistant', 'content': [{'type': 'tool_use', 'id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'name': 'get_current_weather', 'input': {'location': 'San Francisco', 'unit': 'celsius'}}]}, {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]}], 'temperature': 0.2, 'tools': [{'name': 'dummy-tool', 'description': '', 'input_schema': {'type': 'object', 'properties': {}}}], 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'
[0m

_is_function_call: False
RAW RESPONSE:
{"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}


raw model_response: {"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}
Logging Details LiteLLM-Success Call: Cache_hit=None
Looking up model=claude-3-haiku-20240307 in model_cost_map
Looking up model=claude-3-haiku-20240307 in model_cost_map
second response
ModelResponse(id='chatcmpl-c4ed5c25-ba7c-49e5-a6be-5720ab25fff0', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content='The current weather in San Francisco is 72°F (22°C).', role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "Tokyo", "unit": "celsius"}', name='get_current_weather'), id='toolu_01HTXEYDX4MspM76STtJqs1n', type='function')], function_call=None))], created=1727897484, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=90, prompt_tokens=426, total_tokens=516, completion_tokens_details=None))
PASSED

=============================== warnings summary ===============================
../../myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)

../../litellm/utils.py:17
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:17: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
import imghdr

../../litellm/utils.py:124
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:124: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f:

test_function_calling.py:56
/Users/krrishdholakia/Documents/litellm/tests/local_testing/test_function_calling.py:56: PytestUnknownMarkWarning: Unknown pytest.mark.flaky - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@pytest.mark.flaky(retries=3, delay=1)

tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/httpx/_content.py:202: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
warnings.warn(message, DeprecationWarning)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 1 passed, 6 warnings in 1.89s =========================