Complete 'requests' library removal (#7350)

* refactor: initial commit moving watsonx_text to base_llm_http_handler + clarifying new provider directory structure

* refactor(watsonx/completion/handler.py): move to using base llm http handler

removes 'requests' library usage

* fix(watsonx_text/transformation.py): fix result transformation

migrates to transformation.py for use with the base llm http handler

* fix(streaming_handler.py): migrate watsonx streaming to transformation.py

ensures streaming works with base llm http handler

* fix(streaming_handler.py): fix streaming linting errors and remove watsonx conditional logic

* fix(watsonx/): fix chat route after the completion route refactor

* refactor(watsonx/embed): refactor watsonx to use base llm http handler for embedding calls as well

* refactor(base.py): remove requests library usage from litellm

* build(pyproject.toml): remove requests library usage

* fix: fix linting errors

* fix: fix linting errors

* fix(types/utils.py): fix validation errors for modelresponsestream

* fix(replicate/handler.py): fix linting errors

* fix(litellm_logging.py): handle modelresponsestream object

* fix(streaming_handler.py): fix modelresponsestream args

* fix: remove unused imports

* test: fix test

* fix: fix test

* test: fix test

* test: fix tests

* test: fix test

* test: fix patch target

* test: fix test
Krish Dholakia 2024-12-22 07:21:25 -08:00 committed by GitHub
parent 8b1ea40e7b
commit 3671829e39
39 changed files with 2147 additions and 2279 deletions
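
The commit message above describes moving watsonx completion, chat, and embedding calls onto the base LLM HTTP handler and dropping the 'requests' dependency. As a hedged usage sketch of the refactored call path (model IDs and credential values below are placeholders; the environment variable names appear in the credential lookup in this diff):

```python
# Hedged sketch: exercising the refactored watsonx path via the public
# litellm API. Model IDs and credential values are placeholders; the
# environment variable names (WATSONX_API_BASE / WATSONX_API_KEY /
# WATSONX_PROJECT_ID) come from the credential lookup in this PR.
import os

import litellm

os.environ["WATSONX_API_BASE"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_API_KEY"] = "<api-key>"
os.environ["WATSONX_PROJECT_ID"] = "<project-id>"

# Chat/text completion now goes through base_llm_http_handler (httpx),
# with no 'requests' dependency.
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello"}],
)

# Embeddings use the same handler via the new watsonx/embed transformation.
embeddings = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",
    input=["hello world"],
)
```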

.gitignore
View file

@ -67,3 +67,4 @@ litellm/tests/langfuse.log
litellm/proxy/google-cloud-sdk/*
tests/llm_translation/log.txt
venv/
tests/local_testing/log.txt

View file

@ -5,16 +5,16 @@ When adding a new provider, you need to create a directory for the provider that
```
litellm/llms/
└── provider_name/
├── completion/
├── completion/ # use when endpoint is equivalent to openai's `/v1/completions`
│ ├── handler.py
│ └── transformation.py
├── chat/
├── chat/ # use when endpoint is equivalent to openai's `/v1/chat/completions`
│ ├── handler.py
│ └── transformation.py
├── embed/
├── embed/ # use when endpoint is equivalent to openai's `/v1/embeddings`
│ ├── handler.py
│ └── transformation.py
└── rerank/
└── rerank/ # use when endpoint is equivalent to cohere's `/rerank` endpoint.
├── handler.py
└── transformation.py
```
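
As a rough, non-authoritative illustration of this layout, a new provider's `chat/transformation.py` usually defines a config class exposing the hooks the base LLM HTTP handler calls: header validation, URL construction, and request shaping. The class name and endpoint below are placeholders, and the sketch deliberately does not subclass the real `BaseConfig` so it stays self-contained:

```python
from typing import Any, Dict, List, Optional


class ExampleProviderChatConfig:
    """Schematic sketch only; in litellm this would subclass
    litellm.llms.base_llm.chat.transformation.BaseConfig."""

    def validate_environment(
        self,
        headers: Dict[str, str],
        model: str,
        messages: List[Dict[str, Any]],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> Dict[str, str]:
        # Attach auth + content-type headers for the outgoing httpx request.
        headers["Authorization"] = f"Bearer {api_key}"
        headers["Content-Type"] = "application/json"
        return headers

    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        # Build the full endpoint URL; stream-specific routes can branch here.
        return api_base.rstrip("/") + "/v1/chat/completions"

    def transform_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Map OpenAI-style inputs onto the provider's request body.
        return {"model": model, "messages": messages, **optional_params}
```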

View file

@ -991,6 +991,7 @@ from .utils import (
get_api_base,
get_first_chars_messages,
ModelResponse,
ModelResponseStream,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
@ -1157,6 +1158,7 @@ from .llms.perplexity.chat.transformation import PerplexityChatConfig
from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
from .main import * # type: ignore
from .integrations import *
from .exceptions import (

View file

@ -43,6 +43,7 @@ from litellm.types.utils import (
ImageResponse,
LiteLLMLoggingBaseClass,
ModelResponse,
ModelResponseStream,
StandardCallbackDynamicParams,
StandardLoggingAdditionalHeaders,
StandardLoggingHiddenParams,
@ -741,6 +742,7 @@ class Logging(LiteLLMLoggingBaseClass):
self,
result: Union[
ModelResponse,
ModelResponseStream,
EmbeddingResponse,
ImageResponse,
TranscriptionResponse,
@ -848,6 +850,7 @@ class Logging(LiteLLMLoggingBaseClass):
): # handle streaming separately
if (
isinstance(result, ModelResponse)
or isinstance(result, ModelResponseStream)
or isinstance(result, EmbeddingResponse)
or isinstance(result, ImageResponse)
or isinstance(result, TranscriptionResponse)
@ -955,6 +958,7 @@ class Logging(LiteLLMLoggingBaseClass):
if self.stream and (
isinstance(result, litellm.ModelResponse)
or isinstance(result, TextCompletionResponse)
or isinstance(result, ModelResponseStream)
):
complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
@ -966,9 +970,6 @@ class Logging(LiteLLMLoggingBaseClass):
streaming_chunks=self.sync_streaming_chunks,
is_async=False,
)
_caching_complete_streaming_response: Optional[
Union[ModelResponse, TextCompletionResponse]
] = None
if complete_streaming_response is not None:
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
@ -976,9 +977,6 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
_caching_complete_streaming_response = copy.deepcopy(
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
@ -1474,6 +1472,7 @@ class Logging(LiteLLMLoggingBaseClass):
] = None
if self.stream is True and (
isinstance(result, litellm.ModelResponse)
or isinstance(result, litellm.ModelResponseStream)
or isinstance(result, TextCompletionResponse)
):
complete_streaming_response: Optional[

View file

@ -2,7 +2,11 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union
from litellm._logging import verbose_logger
from litellm.types.utils import ModelResponse, TextCompletionResponse
from litellm.types.utils import (
ModelResponse,
ModelResponseStream,
TextCompletionResponse,
)
if TYPE_CHECKING:
from litellm import ModelResponse as _ModelResponse
@ -38,7 +42,7 @@ def convert_litellm_response_object_to_str(
def _assemble_complete_response_from_streaming_chunks(
result: Union[ModelResponse, TextCompletionResponse],
result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream],
start_time: datetime,
end_time: datetime,
request_kwargs: dict,

View file

@ -5,7 +5,7 @@ import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, cast
import httpx
from pydantic import BaseModel
@ -611,44 +611,6 @@ class CustomStreamWrapper:
except Exception as e:
raise e
def handle_watsonx_stream(self, chunk):
try:
if isinstance(chunk, dict):
parsed_response = chunk
elif isinstance(chunk, (str, bytes)):
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
if "generated_text" in chunk:
response = chunk.replace("data: ", "").strip()
parsed_response = json.loads(response)
else:
return {
"text": "",
"is_finished": False,
"prompt_tokens": 0,
"completion_tokens": 0,
}
else:
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
raise ValueError(
f"Unable to parse response. Original response: {chunk}"
)
results = parsed_response.get("results", [])
if len(results) > 0:
text = results[0].get("generated_text", "")
finish_reason = results[0].get("stop_reason")
is_finished = finish_reason != "not_finished"
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"prompt_tokens": results[0].get("input_token_count", 0),
"completion_tokens": results[0].get("generated_token_count", 0),
}
return {"text": "", "is_finished": False}
except Exception as e:
raise e
def handle_triton_stream(self, chunk):
try:
if isinstance(chunk, dict):
@ -702,9 +664,18 @@ class CustomStreamWrapper:
# pop model keyword
chunk.pop("model", None)
model_response = ModelResponse(
stream=True, model=_model, stream_options=self.stream_options, **chunk
)
chunk_dict = {}
for key, value in chunk.items():
if key != "stream":
chunk_dict[key] = value
args = {
"model": _model,
"stream_options": self.stream_options,
**chunk_dict,
}
model_response = ModelResponseStream(**args)
if self.response_id is not None:
model_response.id = self.response_id
else:
@ -742,9 +713,9 @@ class CustomStreamWrapper:
def return_processed_chunk_logic( # noqa
self,
completion_obj: dict,
completion_obj: Dict[str, Any],
model_response: ModelResponseStream,
response_obj: dict,
response_obj: Dict[str, Any],
):
print_verbose(
@ -887,11 +858,11 @@ class CustomStreamWrapper:
def chunk_creator(self, chunk): # type: ignore # noqa: PLR0915
model_response = self.model_response_creator()
response_obj: dict = {}
response_obj: Dict[str, Any] = {}
try:
# return this for all models
completion_obj = {"content": ""}
completion_obj: Dict[str, Any] = {"content": ""}
from litellm.types.utils import GenericStreamingChunk as GChunk
if (
@ -1089,11 +1060,6 @@ class CustomStreamWrapper:
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "watsonx":
response_obj = self.handle_watsonx_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "triton":
response_obj = self.handle_triton_stream(chunk)
completion_obj["content"] = response_obj["text"]
@ -1158,7 +1124,7 @@ class CustomStreamWrapper:
self.received_finish_reason = response_obj["finish_reason"]
else: # openai / azure chat model
if self.custom_llm_provider == "azure":
if hasattr(chunk, "model"):
if isinstance(chunk, BaseModel) and hasattr(chunk, "model"):
# for azure, we need to pass the model from the original chunk
self.model = chunk.model
response_obj = self.handle_openai_chat_completion_chunk(chunk)
@ -1190,21 +1156,29 @@ class CustomStreamWrapper:
if response_obj["usage"] is not None:
if isinstance(response_obj["usage"], dict):
model_response.usage = litellm.Usage(
prompt_tokens=response_obj["usage"].get(
"prompt_tokens", None
)
or None,
completion_tokens=response_obj["usage"].get(
"completion_tokens", None
)
or None,
total_tokens=response_obj["usage"].get("total_tokens", None)
or None,
setattr(
model_response,
"usage",
litellm.Usage(
prompt_tokens=response_obj["usage"].get(
"prompt_tokens", None
)
or None,
completion_tokens=response_obj["usage"].get(
"completion_tokens", None
)
or None,
total_tokens=response_obj["usage"].get(
"total_tokens", None
)
or None,
),
)
elif isinstance(response_obj["usage"], BaseModel):
model_response.usage = litellm.Usage(
**response_obj["usage"].model_dump()
setattr(
model_response,
"usage",
litellm.Usage(**response_obj["usage"].model_dump()),
)
model_response.model = self.model
@ -1337,7 +1311,7 @@ class CustomStreamWrapper:
raise StopIteration
except Exception as e:
traceback.format_exc()
e.message = str(e)
setattr(e, "message", str(e))
raise exception_type(
model=self.model,
custom_llm_provider=self.custom_llm_provider,
@ -1434,7 +1408,9 @@ class CustomStreamWrapper:
print_verbose(
f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
)
response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
response: Optional[ModelResponseStream] = self.chunk_creator(
chunk=chunk
)
print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
if response is None:
@ -1597,7 +1573,7 @@ class CustomStreamWrapper:
# __anext__ also calls async_success_handler, which does logging
print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk: Optional[ModelResponse] = self.chunk_creator(
processed_chunk: Optional[ModelResponseStream] = self.chunk_creator(
chunk=chunk
)
print_verbose(
@ -1624,7 +1600,7 @@ class CustomStreamWrapper:
if self.logging_obj._llm_caching_handler is not None:
asyncio.create_task(
self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
processed_chunk=processed_chunk,
processed_chunk=cast(ModelResponse, processed_chunk),
)
)
@ -1663,8 +1639,8 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk: Optional[ModelResponse] = self.chunk_creator(
chunk=chunk
processed_chunk: Optional[ModelResponseStream] = (
self.chunk_creator(chunk=chunk)
)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"

View file

@ -2,7 +2,6 @@
from typing import Any, Optional, Union
import httpx
import requests
import litellm
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
@ -16,7 +15,7 @@ class BaseLLM:
def process_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
response: httpx.Response,
model_response: ModelResponse,
stream: bool,
logging_obj: Any,
@ -35,7 +34,7 @@ class BaseLLM:
def process_text_completion_response(
self,
model: str,
response: Union[requests.Response, httpx.Response],
response: httpx.Response,
model_response: TextCompletionResponse,
stream: bool,
logging_obj: Any,

View file

@ -107,7 +107,13 @@ class BaseConfig(ABC):
) -> dict:
pass
def get_complete_url(self, api_base: str, model: str) -> str:
def get_complete_url(
self,
api_base: str,
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL

View file

@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, ModelResponse
if TYPE_CHECKING:
@ -16,12 +16,11 @@ else:
class BaseEmbeddingConfig(BaseConfig, ABC):
@abstractmethod
def transform_embedding_request(
self,
model: str,
input: Union[str, List[str], List[float], List[List[float]]],
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:
@ -34,14 +33,20 @@ class BaseEmbeddingConfig(BaseConfig, ABC):
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str] = None,
request_data: dict = {},
optional_params: dict = {},
litellm_params: dict = {},
api_key: Optional[str],
request_data: dict,
optional_params: dict,
litellm_params: dict,
) -> EmbeddingResponse:
return model_response
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL

View file

@ -72,7 +72,13 @@ class CloudflareChatConfig(BaseConfig):
}
return headers
def get_complete_url(self, api_base: str, model: str) -> str:
def get_complete_url(
self,
api_base: str,
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
return api_base + model
def get_supported_openai_params(self, model: str) -> List[str]:

View file

@ -110,6 +110,8 @@ class BaseLLMHTTPHandler:
api_base = provider_config.get_complete_url(
api_base=api_base,
model=model,
optional_params=optional_params,
stream=stream,
)
data = provider_config.transform_request(
@ -402,6 +404,7 @@ class BaseLLMHTTPHandler:
logging_obj: LiteLLMLoggingObj,
api_base: Optional[str],
optional_params: dict,
litellm_params: dict,
model_response: EmbeddingResponse,
api_key: Optional[str] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
@ -424,6 +427,7 @@ class BaseLLMHTTPHandler:
api_base = provider_config.get_complete_url(
api_base=api_base,
model=model,
optional_params=optional_params,
)
data = provider_config.transform_embedding_request(
@ -457,6 +461,8 @@ class BaseLLMHTTPHandler:
api_key=api_key,
timeout=timeout,
client=client,
optional_params=optional_params,
litellm_params=litellm_params,
)
if client is None or not isinstance(client, HTTPHandler):
@ -484,6 +490,8 @@ class BaseLLMHTTPHandler:
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
optional_params=optional_params,
litellm_params=litellm_params,
)
async def aembedding(
@ -496,6 +504,8 @@ class BaseLLMHTTPHandler:
provider_config: BaseEmbeddingConfig,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
@ -524,6 +534,8 @@ class BaseLLMHTTPHandler:
logging_obj=logging_obj,
api_key=api_key,
request_data=request_data,
optional_params=optional_params,
litellm_params=litellm_params,
)
def rerank(

View file

@ -350,7 +350,13 @@ class OllamaConfig(BaseConfig):
) -> dict:
return headers
def get_complete_url(self, api_base: str, model: str) -> str:
def get_complete_url(
self,
api_base: str,
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
OPTIONAL

View file

@ -168,7 +168,9 @@ def completion(
time.time()
) # for pricing this must remain right before calling api
prediction_url = replicate_config.get_complete_url(api_base, model)
prediction_url = replicate_config.get_complete_url(
api_base=api_base, model=model, optional_params=optional_params
)
## COMPLETION CALL
httpx_client = _get_httpx_client(
@ -235,7 +237,9 @@ async def async_completion(
headers: dict,
) -> Union[ModelResponse, CustomStreamWrapper]:
prediction_url = replicate_config.get_complete_url(api_base=api_base, model=model)
prediction_url = replicate_config.get_complete_url(
api_base=api_base, model=model, optional_params=optional_params
)
async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.REPLICATE,
params={"timeout": 600.0},

View file

@ -136,7 +136,13 @@ class ReplicateConfig(BaseConfig):
status_code=status_code, message=error_message, headers=headers
)
def get_complete_url(self, api_base: str, model: str) -> str:
def get_complete_url(
self,
api_base: str,
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
version_id = self.model_to_version_id(model)
base_url = api_base
if "deployments" in version_id:

View file

@ -7,6 +7,7 @@ from litellm.llms.base_llm.embedding.transformation import (
BaseEmbeddingConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse
from ..common_utils import TritonError
@ -48,7 +49,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
def transform_embedding_request(
self,
model: str,
input: Union[str, List[str], List[float], List[List[float]]],
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:

View file

@ -6,7 +6,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
from litellm.types.utils import EmbeddingResponse, Usage
@ -38,7 +38,13 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
def __init__(self) -> None:
pass
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
if api_base:
if not api_base.endswith("/embeddings"):
api_base = f"{api_base}/embeddings"
@ -90,7 +96,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
def transform_embedding_request(
self,
model: str,
input: Union[str, List[str], List[float], List[List[float]]],
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:

View file

@ -3,57 +3,19 @@ from typing import Callable, Optional, Union
import httpx
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
from ...openai_like.chat.handler import OpenAILikeChatHandler
from ..common_utils import WatsonXAIError, _get_api_params
from ..common_utils import _get_api_params
from .transformation import IBMWatsonXChatConfig
watsonx_chat_transformation = IBMWatsonXChatConfig()
class WatsonXChatHandler(OpenAILikeChatHandler):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _prepare_url(
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
) -> str:
if model.startswith("deployment/"):
if api_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model.split("/")[1:])
endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
if stream is True
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
)
endpoint = endpoint.format(deployment_id=deployment_id)
else:
endpoint = (
WatsonXAIEndpoint.CHAT_STREAM.value
if stream is True
else WatsonXAIEndpoint.CHAT.value
)
base_url = httpx.URL(api_params["url"])
base_url = base_url.join(endpoint)
full_url = str(
base_url.copy_add_param(key="version", value=api_params["api_version"])
)
return full_url
def _prepare_payload(
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
) -> dict:
payload: dict = {}
if model.startswith("deployment/"):
return payload
payload["model_id"] = model
payload["project_id"] = api_params["project_id"]
return payload
def completion(
self,
*,
@ -70,32 +32,37 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
logger_fn=None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
custom_endpoint: Optional[bool] = None,
streaming_decoder: Optional[CustomStreamingDecoder] = None,
fake_stream: bool = False,
):
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
api_params = _get_api_params(params=optional_params)
if headers is None:
headers = {}
headers.update(
{
"Authorization": f"Bearer {api_params['token']}",
"Content-Type": "application/json",
"Accept": "application/json",
}
## UPDATE HEADERS
headers = watsonx_chat_transformation.validate_environment(
headers=headers or {},
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)
stream: Optional[bool] = optional_params.get("stream", False)
## GET API URL
api_base = watsonx_chat_transformation.get_complete_url(
api_base=api_base,
model=model,
optional_params=optional_params,
stream=optional_params.get("stream", False),
)
## get api url and payload
api_base = self._prepare_url(model=model, api_params=api_params, stream=stream)
watsonx_auth_payload = self._prepare_payload(
model=model, api_params=api_params, stream=stream
## UPDATE PAYLOAD (optional params)
watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
model=model,
api_params=api_params,
)
optional_params.update(watsonx_auth_payload)

View file

@ -7,12 +7,14 @@ Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
from typing import List, Optional, Tuple, Union
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
from ..common_utils import IBMWatsonXMixin, WatsonXAIError
class IBMWatsonXChatConfig(OpenAIGPTConfig):
class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
def get_supported_openai_params(self, model: str) -> List:
return [
@ -75,3 +77,47 @@ class IBMWatsonXChatConfig(OpenAIGPTConfig):
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
) # vllm does not require an api key
return api_base, dynamic_api_key
def get_complete_url(
self,
api_base: str,
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
url = self._get_base_url(api_base=api_base)
if model.startswith("deployment/"):
# deployment models are passed in as 'deployment/<deployment_id>'
if optional_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model.split("/")[1:])
endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
if stream
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
)
endpoint = endpoint.format(deployment_id=deployment_id)
else:
endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
if stream
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
)
url = url.rstrip("/") + endpoint
## add api version
url = self._add_api_version_to_url(
url=url, api_version=optional_params.pop("api_version", None)
)
return url
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
payload: dict = {}
if model.startswith("deployment/"):
return payload
payload["model_id"] = model
payload["project_id"] = api_params["project_id"]
return payload

View file

@ -1,13 +1,15 @@
from typing import Callable, Dict, Optional, Union, cast
from typing import Dict, List, Optional, Union, cast
import httpx
import litellm
from litellm import verbose_logger
from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAPIParams
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials
class WatsonXAIError(BaseLLMException):
@ -65,18 +67,20 @@ def generate_iam_token(api_key=None, **params) -> str:
return cast(str, result)
def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
if token is not None:
return token
token = generate_iam_token(api_key)
return token
def _get_api_params(
params: dict,
print_verbose: Optional[Callable] = None,
generate_token: Optional[bool] = True,
) -> WatsonXAPIParams:
"""
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
"""
# Load auth variables from params
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
api_key = params.pop("apikey", None)
token = params.pop("token", None)
project_id = params.pop(
"project_id", params.pop("watsonx_project", None)
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
@ -86,29 +90,8 @@ def _get_api_params(
region_name = params.pop(
"watsonx_region_name", params.pop("watsonx_region", None)
) # consistent with how vertex ai + aws regions are accepted
wx_credentials = params.pop(
"wx_credentials",
params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
# Load auth variables from environment variables
if url is None:
url = (
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if api_key is None:
api_key = (
get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if token is None:
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
if project_id is None:
project_id = (
get_secret_str("WATSONX_PROJECT_ID")
@ -129,34 +112,6 @@ def _get_api_params(
or get_secret_str("SPACE_ID")
)
# credentials parsing
if wx_credentials is not None:
url = wx_credentials.get("url", url)
api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", token
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
# verify that all required credentials are present
if url is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
)
if token is None and api_key is not None and generate_token:
# generate the auth token
if print_verbose is not None:
print_verbose("Generating IAM token for Watsonx.ai")
token = generate_iam_token(api_key)
elif token is None and api_key is None:
raise WatsonXAIError(
status_code=401,
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
)
if project_id is None:
raise WatsonXAIError(
status_code=401,
@ -164,11 +119,147 @@ def _get_api_params(
)
return WatsonXAPIParams(
url=url,
api_key=api_key,
token=cast(str, token),
project_id=project_id,
space_id=space_id,
region_name=region_name,
api_version=api_version,
)
def convert_watsonx_messages_to_prompt(
model: str,
messages: List[AllMessageValues],
provider: str,
custom_prompt_dict: Dict,
) -> str:
# handle anthropic prompts and amazon titan prompts
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_dict = custom_prompt_dict[model]
prompt = ptf.custom_prompt(
messages=messages,
role_dict=model_prompt_dict.get(
"role_dict", model_prompt_dict.get("roles")
),
initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
bos_token=model_prompt_dict.get("bos_token", ""),
eos_token=model_prompt_dict.get("eos_token", ""),
)
return prompt
elif provider == "ibm-mistralai":
prompt = ptf.mistral_instruct_pt(messages=messages)
else:
prompt: str = ptf.prompt_factory( # type: ignore
model=model, messages=messages, custom_llm_provider="watsonx"
)
return prompt
# Mixin class for shared IBM Watson X functionality
class IBMWatsonXMixin:
def validate_environment(
self,
headers: Dict,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
) -> Dict:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
token = cast(Optional[str], optional_params.get("token"))
if token:
headers["Authorization"] = f"Bearer {token}"
else:
token = _generate_watsonx_token(api_key=api_key, token=token)
# build auth headers
headers["Authorization"] = f"Bearer {token}"
return headers
def _get_base_url(self, api_base: Optional[str]) -> str:
url = (
api_base
or get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if url is None:
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
)
return url
def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
url = url + f"?version={api_version}"
return url
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
return WatsonXAIError(
status_code=status_code, message=error_message, headers=headers
)
@staticmethod
def get_watsonx_credentials(
optional_params: dict, api_key: Optional[str], api_base: Optional[str]
) -> WatsonXCredentials:
api_key = (
api_key
or optional_params.pop("apikey", None)
or get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
api_base = (
api_base
or optional_params.pop(
"url",
optional_params.pop("api_base", optional_params.pop("base_url", None)),
)
or get_secret_str("WATSONX_API_BASE")
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
wx_credentials = optional_params.pop(
"wx_credentials",
optional_params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
token: Optional[str] = None
if wx_credentials is not None:
api_base = wx_credentials.get("url", api_base)
api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key)
)
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", None
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
if api_key is None or not isinstance(api_key, str):
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
)
if api_base is None or not isinstance(api_base, str):
raise WatsonXAIError(
status_code=401,
message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
)
return WatsonXCredentials(
api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
)
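
A hedged sketch of how the `wx_credentials` / `watsonx_credentials` path handled above could be exercised. The key names ("url", "apikey"/"api_key", "token"/"watsonx_token") are taken from the lookup in `get_watsonx_credentials`; that litellm forwards these exact kwargs end to end is an assumption here, and all values are placeholders:

```python
# Hedged sketch; key names mirror the lookup in get_watsonx_credentials above.
import litellm

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello"}],
    wx_credentials={
        "url": "https://us-south.ml.cloud.ibm.com",
        "apikey": "<api-key>",
        # "token": "<pre-generated IAM token>",  # alternative to apikey
    },
    project_id="<project-id>",
)
```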

View file

@ -1,551 +1,3 @@
import asyncio
import json # noqa: E401
import time
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Generator,
Iterator,
List,
Optional,
Union,
)
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.prompt_templates import factory as ptf
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import AllMessageValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from ...base import BaseLLM
from ..common_utils import WatsonXAIError, _get_api_params
from .transformation import IBMWatsonXAIConfig
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
# handle anthropic prompts and amazon titan prompts
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_dict = custom_prompt_dict[model]
prompt = ptf.custom_prompt(
messages=messages,
role_dict=model_prompt_dict.get(
"role_dict", model_prompt_dict.get("roles")
),
initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
bos_token=model_prompt_dict.get("bos_token", ""),
eos_token=model_prompt_dict.get("eos_token", ""),
)
return prompt
elif provider == "ibm-mistralai":
prompt = ptf.mistral_instruct_pt(messages=messages)
else:
prompt: str = ptf.prompt_factory( # type: ignore
model=model, messages=messages, custom_llm_provider="watsonx"
)
return prompt
class IBMWatsonXAI(BaseLLM):
"""
Class to interface with IBM watsonx.ai API for text generation and embeddings.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
"""
api_version = "2024-03-13"
def __init__(self) -> None:
super().__init__()
def _prepare_text_generation_req(
self,
model_id: str,
messages: List[AllMessageValues],
prompt: str,
stream: bool,
optional_params: dict,
print_verbose: Optional[Callable] = None,
) -> dict:
"""
Get the request parameters for text generation.
"""
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
# build auth headers
api_token = api_params.get("token")
self.token = api_token
headers = IBMWatsonXAIConfig().validate_environment(
headers={},
model=model_id,
messages=messages,
optional_params=optional_params,
api_key=api_token,
)
extra_body_params = optional_params.pop("extra_body", {})
optional_params.update(extra_body_params)
# init the payload to the text generation call
payload = {
"input": prompt,
"moderations": optional_params.pop("moderations", {}),
"parameters": optional_params,
}
request_params = dict(version=api_params["api_version"])
# text generation endpoint deployment or model / stream or not
if model_id.startswith("deployment/"):
# deployment models are passed in as 'deployment/<deployment_id>'
if api_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model_id.split("/")[1:])
endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
if stream
else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
)
endpoint = endpoint.format(deployment_id=deployment_id)
else:
payload["model_id"] = model_id
payload["project_id"] = api_params["project_id"]
endpoint = (
WatsonXAIEndpoint.TEXT_GENERATION_STREAM
if stream
else WatsonXAIEndpoint.TEXT_GENERATION
)
url = api_params["url"].rstrip("/") + endpoint
return dict(
method="POST", url=url, headers=headers, json=payload, params=request_params
)
def _process_text_gen_response(
self, json_resp: dict, model_response: Union[ModelResponse, None] = None
) -> ModelResponse:
if "results" not in json_resp:
raise WatsonXAIError(
status_code=500,
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
)
if model_response is None:
model_response = ModelResponse(model=json_resp.get("model_id", None))
generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"]
model_response.choices[0].message.content = generated_text # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(
json_resp["results"][0]["stop_reason"]
)
if json_resp.get("created_at"):
model_response.created = int(
datetime.fromisoformat(json_resp["created_at"]).timestamp()
)
else:
model_response.created = int(time.time())
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def completion(
self,
model: str,
messages: list,
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj: Any,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,
timeout=None,
):
"""
Send a text generation request to the IBM Watsonx.ai API.
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
"""
stream = optional_params.pop("stream", False)
# Load default configs
config = IBMWatsonXAIConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
# Make prompt to send to model
provider = model.split("/")[0]
# model_name = "/".join(model.split("/")[1:])
prompt = convert_messages_to_prompt(
model, messages, provider, custom_prompt_dict
)
model_response.model = model
def process_stream_response(
stream_resp: Union[Iterator[str], AsyncIterator],
) -> CustomStreamWrapper:
streamwrapper = litellm.CustomStreamWrapper(
stream_resp,
model=model,
custom_llm_provider="watsonx",
logging_obj=logging_obj,
)
return streamwrapper
# create the function to manage the request to watsonx.ai
self.request_manager = RequestManager(logging_obj)
def handle_text_request(request_params: dict) -> ModelResponse:
with self.request_manager.request(
request_params,
input=prompt,
timeout=timeout,
) as resp:
json_resp = resp.json()
return self._process_text_gen_response(json_resp, model_response)
async def handle_text_request_async(request_params: dict) -> ModelResponse:
async with self.request_manager.async_request(
request_params,
input=prompt,
timeout=timeout,
) as resp:
json_resp = resp.json()
return self._process_text_gen_response(json_resp, model_response)
def handle_stream_request(request_params: dict) -> CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
with self.request_manager.request(
request_params,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
streamwrapper = process_stream_response(resp.iter_lines())
return streamwrapper
async def handle_stream_request_async(
request_params: dict,
) -> CustomStreamWrapper:
# stream the response - generated chunks will be handled
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
async with self.request_manager.async_request(
request_params,
stream=True,
input=prompt,
timeout=timeout,
) as resp:
streamwrapper = process_stream_response(resp.aiter_lines())
return streamwrapper
try:
## Get the response from the model
req_params = self._prepare_text_generation_req(
model_id=model,
prompt=prompt,
messages=messages,
stream=stream,
optional_params=optional_params,
print_verbose=print_verbose,
)
if stream and (acompletion is True):
# stream and async text generation
return handle_stream_request_async(req_params)
elif stream:
# streaming text generation
return handle_stream_request(req_params)
elif acompletion is True:
# async text generation
return handle_text_request_async(req_params)
else:
# regular text generation
return handle_text_request(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def _process_embedding_response(
self, json_resp: dict, model_response: Optional[EmbeddingResponse] = None
) -> EmbeddingResponse:
if model_response is None:
model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
results = json_resp.get("results", [])
embedding_response = []
for idx, result in enumerate(results):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": result["embedding"],
}
)
model_response.object = "list"
model_response.data = embedding_response
input_tokens = json_resp.get("input_token_count", 0)
setattr(
model_response,
"usage",
Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
),
)
return model_response
def embedding(
self,
model: str,
input: Union[list, str],
model_response: EmbeddingResponse,
api_key: Optional[str],
logging_obj: Any,
optional_params: dict,
encoding=None,
print_verbose=None,
aembedding=None,
) -> EmbeddingResponse:
"""
Send a text embedding request to the IBM Watsonx.ai API.
"""
if optional_params is None:
optional_params = {}
# Load default configs
config = IBMWatsonXAIConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
model_response.model = model
# Load auth variables from environment variables
if isinstance(input, str):
input = [input]
if api_key is not None:
optional_params["api_key"] = api_key
api_params = _get_api_params(optional_params)
# build auth headers
api_token = api_params.get("token")
self.token = api_token
headers = {
"Authorization": f"Bearer {api_token}",
"Content-Type": "application/json",
"Accept": "application/json",
}
# init the payload to the text generation call
payload = {
"inputs": input,
"model_id": model,
"project_id": api_params["project_id"],
"parameters": optional_params,
}
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
req_params = {
"method": "POST",
"url": url,
"headers": headers,
"json": payload,
"params": request_params,
}
request_manager = RequestManager(logging_obj)
def handle_embedding(request_params: dict) -> EmbeddingResponse:
with request_manager.request(request_params, input=input) as resp:
json_resp = resp.json()
return self._process_embedding_response(json_resp, model_response)
async def handle_aembedding(request_params: dict) -> EmbeddingResponse:
async with request_manager.async_request(
request_params, input=input
) as resp:
json_resp = resp.json()
return self._process_embedding_response(json_resp, model_response)
try:
if aembedding is True:
return handle_aembedding(req_params) # type: ignore
else:
return handle_embedding(req_params)
except WatsonXAIError as e:
raise e
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
def get_available_models(self, *, ids_only: bool = True, **params):
api_params = _get_api_params(params)
self.token = api_params["token"]
headers = {
"Authorization": f"Bearer {api_params['token']}",
"Content-Type": "application/json",
"Accept": "application/json",
}
request_params = dict(version=api_params["api_version"])
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
req_params = dict(method="GET", url=url, headers=headers, params=request_params)
with RequestManager(logging_obj=None).request(req_params) as resp:
json_resp = resp.json()
if not ids_only:
return json_resp
return [res["model_id"] for res in json_resp["resources"]]
class RequestManager:
"""
A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.
Usage:
```python
request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
request_manager = RequestManager(logging_obj=logging_obj)
with request_manager.request(request_params) as resp:
...
# or
async with request_manager.async_request(request_params) as resp:
...
```
"""
def __init__(self, logging_obj=None):
self.logging_obj = logging_obj
def pre_call(
self,
request_params: dict,
input: Optional[Any] = None,
is_async: Optional[bool] = False,
):
if self.logging_obj is None:
return
request_str = (
f"response = {'await ' if is_async else ''}{request_params['method']}(\n"
f"\turl={request_params['url']},\n"
f"\tjson={request_params.get('json')},\n"
f")"
)
self.logging_obj.pre_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
additional_args={
"complete_input_dict": request_params.get("json"),
"request_str": request_str,
},
)
def post_call(self, resp, request_params):
if self.logging_obj is None:
return
self.logging_obj.post_call(
input=input,
api_key=request_params["headers"].get("Authorization"),
original_response=json.dumps(resp.json()),
additional_args={
"status_code": resp.status_code,
"complete_input_dict": request_params.get(
"data", request_params.get("json")
),
},
)
@contextmanager
def request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> Generator[requests.Response, None, None]:
"""
Returns a context manager that yields the response from the request.
"""
self.pre_call(request_params, input)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
resp = requests.request(**request_params)
if not resp.ok:
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
)
yield resp
except Exception as e:
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
@asynccontextmanager
async def async_request(
self,
request_params: dict,
stream: bool = False,
input: Optional[Any] = None,
timeout=None,
) -> AsyncGenerator[httpx.Response, None]:
self.pre_call(request_params, input, is_async=True)
if timeout:
request_params["timeout"] = timeout
if stream:
request_params["stream"] = stream
try:
self.async_handler = get_async_httpx_client(
llm_provider=litellm.LlmProviders.WATSONX,
params={
"timeout": httpx.Timeout(
timeout=request_params.pop("timeout", 600.0), connect=5.0
),
},
)
if "json" in request_params:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
retries = 0
resp: Optional[httpx.Response] = None
while retries < 3:
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp is not None and resp.status_code in [429, 503, 504, 520]:
# to handle rate limiting and service unavailable errors
# see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
await asyncio.sleep(2**retries)
retries += 1
else:
break
if resp is None:
raise WatsonXAIError(
status_code=500,
message="No response from the server",
)
if resp.is_error:
error_reason = getattr(resp, "reason", "")
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({error_reason}): {resp.text}",
)
yield resp
# await async_handler.close()
except Exception as e:
raise e
raise WatsonXAIError(status_code=500, message=str(e))
if not stream:
self.post_call(resp, request_params)
"""
Watsonx uses the llm_http_handler.py to handle the requests.
"""

View file

@ -1,13 +1,31 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import time
from datetime import datetime
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Union,
)
import httpx
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage
from litellm.utils import map_finish_reason
from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import WatsonXAIError
from ..common_utils import (
IBMWatsonXMixin,
WatsonXAIError,
_get_api_params,
convert_watsonx_messages_to_prompt,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@ -17,7 +35,7 @@ else:
LiteLLMLoggingObj = Any
class IBMWatsonXAIConfig(BaseConfig):
class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
"""
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
(See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
@ -210,13 +228,6 @@ class IBMWatsonXAIConfig(BaseConfig):
"us-south",
]
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
) -> BaseLLMException:
return WatsonXAIError(
status_code=status_code, message=error_message, headers=headers
)
def transform_request(
self,
model: str,
@ -225,9 +236,28 @@ class IBMWatsonXAIConfig(BaseConfig):
litellm_params: Dict,
headers: Dict,
) -> Dict:
raise NotImplementedError(
"transform_request not implemented. Done in watsonx/completion handler.py"
provider = model.split("/")[0]
prompt = convert_watsonx_messages_to_prompt(
model=model,
messages=messages,
provider=provider,
custom_prompt_dict={},
)
extra_body_params = optional_params.pop("extra_body", {})
optional_params.update(extra_body_params)
watsonx_api_params = _get_api_params(params=optional_params)
# init the payload to the text generation call
payload = {
"input": prompt,
"moderations": optional_params.pop("moderations", {}),
"parameters": optional_params,
}
if not model.startswith("deployment/"):
payload["model_id"] = model
payload["project_id"] = watsonx_api_params["project_id"]
return payload
def transform_response(
self,
@ -243,22 +273,120 @@ class IBMWatsonXAIConfig(BaseConfig):
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:
raise NotImplementedError(
"transform_response not implemented. Done in watsonx/completion handler.py"
## LOGGING
logging_obj.post_call(
input=messages,
api_key="",
original_response=raw_response.text,
)
def validate_environment(
json_resp = raw_response.json()
if "results" not in json_resp:
raise WatsonXAIError(
status_code=500,
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
)
if model_response is None:
model_response = ModelResponse(model=json_resp.get("model_id", None))
generated_text = json_resp["results"][0]["generated_text"]
prompt_tokens = json_resp["results"][0]["input_token_count"]
completion_tokens = json_resp["results"][0]["generated_token_count"]
model_response.choices[0].message.content = generated_text # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(
json_resp["results"][0]["stop_reason"]
)
if json_resp.get("created_at"):
model_response.created = int(
datetime.fromisoformat(json_resp["created_at"]).timestamp()
)
else:
model_response.created = int(time.time())
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
def get_complete_url(
self,
headers: Dict,
api_base: str,
model: str,
messages: List[AllMessageValues],
optional_params: Dict,
api_key: Optional[str] = None,
) -> Dict:
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
url = self._get_base_url(api_base=api_base)
if model.startswith("deployment/"):
# deployment models are passed in as 'deployment/<deployment_id>'
if optional_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model.split("/")[1:])
endpoint = (
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
if stream
else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
)
endpoint = endpoint.format(deployment_id=deployment_id)
else:
endpoint = (
WatsonXAIEndpoint.TEXT_GENERATION_STREAM
if stream
else WatsonXAIEndpoint.TEXT_GENERATION
)
url = url.rstrip("/") + endpoint
## add api version
url = self._add_api_version_to_url(
url=url, api_version=optional_params.pop("api_version", None)
)
return url
def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return WatsonxTextCompletionResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)
class WatsonxTextCompletionResponseIterator(BaseModelResponseIterator):
# def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
# return self.chunk_parser(json.loads(str_line))
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
results = chunk.get("results", [])
if len(results) > 0:
text = results[0].get("generated_text", "")
finish_reason = results[0].get("stop_reason")
is_finished = finish_reason != "not_finished"
return GenericStreamingChunk(
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
usage=ChatCompletionUsageBlock(
prompt_tokens=results[0].get("input_token_count", 0),
completion_tokens=results[0].get("generated_token_count", 0),
total_tokens=results[0].get("input_token_count", 0)
+ results[0].get("generated_token_count", 0),
),
)
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="stop",
usage=None,
)
except Exception as e:
raise e
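
For reference, a hypothetical streaming chunk shaped the way `chunk_parser` above expects it; the field names are taken from the parser itself and the values are illustrative only:

```python
# Hypothetical watsonx streaming chunk; field names mirror chunk_parser above.
chunk = {
    "model_id": "ibm/granite-13b-instruct-v2",
    "results": [
        {
            "generated_text": "Hello",
            "generated_token_count": 1,
            "input_token_count": 4,
            "stop_reason": "not_finished",
        }
    ],
}
# chunk_parser(chunk) would yield a GenericStreamingChunk with text="Hello",
# is_finished=False, finish_reason="not_finished", and the token counts above.
```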

View file

@ -0,0 +1,116 @@
"""
Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
"""
from typing import Optional
import httpx
from litellm.llms.base_llm.embedding.transformation import (
BaseEmbeddingConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import EmbeddingResponse, Usage
from ..common_utils import IBMWatsonXMixin, WatsonXAIError, _get_api_params
class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
def get_supported_openai_params(self, model: str) -> list:
return []
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
return optional_params
def transform_embedding_request(
self,
model: str,
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:
watsonx_api_params = _get_api_params(params=optional_params)
project_id = watsonx_api_params["project_id"]
if not project_id:
raise ValueError("project_id is required")
return {
"inputs": input,
"model_id": model,
"project_id": project_id,
"parameters": optional_params,
}
def get_complete_url(
self,
api_base: Optional[str],
model: str,
optional_params: dict,
stream: Optional[bool] = None,
) -> str:
url = self._get_base_url(api_base=api_base)
endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
if model.startswith("deployment/"):
# deployment models are passed in as 'deployment/<deployment_id>'
if optional_params.get("space_id") is None:
raise WatsonXAIError(
status_code=401,
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
)
deployment_id = "/".join(model.split("/")[1:])
endpoint = endpoint.format(deployment_id=deployment_id)
url = url.rstrip("/") + endpoint
## add api version
url = self._add_api_version_to_url(
url=url, api_version=optional_params.pop("api_version", None)
)
return url
def transform_embedding_response(
self,
model: str,
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
request_data: dict,
optional_params: dict,
litellm_params: dict,
) -> EmbeddingResponse:
logging_obj.post_call(
original_response=raw_response.text,
)
json_resp = raw_response.json()
if model_response is None:
model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
results = json_resp.get("results", [])
embedding_response = []
for idx, result in enumerate(results):
embedding_response.append(
{
"object": "embedding",
"index": idx,
"embedding": result["embedding"],
}
)
model_response.object = "list"
model_response.data = embedding_response
input_tokens = json_resp.get("input_token_count", 0)
setattr(
model_response,
"usage",
Usage(
prompt_tokens=input_tokens,
completion_tokens=0,
total_tokens=input_tokens,
),
)
return model_response
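For reference, a hedged usage sketch of the embedding route this config enables. The model and input mirror the updated test further down in this diff; the credential setup (WATSONX_APIKEY / WATSONX_URL plus an explicit project_id) is an assumption about the caller's environment.

```python
# Sketch, not part of the diff: assumes watsonx credentials are configured.
import litellm

response = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",
    input=["good morning from litellm"],
    project_id="my-project-id",  # placeholder; required by transform_embedding_request
)
print(response.usage)  # prompt_tokens is populated from watsonx's input_token_count
```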

View file

@ -145,7 +145,7 @@ from .llms.vertex_ai.vertex_embeddings.embedding_handler import VertexEmbedding
from .llms.vertex_ai.vertex_model_garden.main import VertexAIModelGardenModels
from .llms.vllm.completion import handler as vllm_handler
from .llms.watsonx.chat.handler import WatsonXChatHandler
from .llms.watsonx.completion.handler import IBMWatsonXAI
from .llms.watsonx.common_utils import IBMWatsonXMixin
from .types.llms.openai import (
ChatCompletionAssistantMessage,
ChatCompletionAudioParam,
@ -205,7 +205,6 @@ google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_model_garden_chat_completion = VertexAIModelGardenModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()
@ -2585,43 +2584,68 @@ def completion( # type: ignore # noqa: PLR0915
custom_llm_provider="watsonx",
)
elif custom_llm_provider == "watsonx_text":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
response = watsonxai.completion(
model=model,
messages=messages,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
timeout=timeout, # type: ignore
acompletion=acompletion,
api_key = (
api_key
or optional_params.pop("apikey", None)
or get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if (
"stream" in optional_params
and optional_params["stream"] is True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
response = CustomStreamWrapper(
iter(response),
model,
custom_llm_provider="watsonx",
logging_obj=logging,
api_base = (
api_base
or optional_params.pop(
"url",
optional_params.pop(
"api_base", optional_params.pop("base_url", None)
),
)
or get_secret_str("WATSONX_API_BASE")
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
wx_credentials = optional_params.pop(
"wx_credentials",
optional_params.pop(
"watsonx_credentials", None
), # follow {provider}_credentials, same as vertex ai
)
token: Optional[str] = None
if wx_credentials is not None:
api_base = wx_credentials.get("url", api_base)
api_key = wx_credentials.get(
"apikey", wx_credentials.get("api_key", api_key)
)
token = wx_credentials.get(
"token",
wx_credentials.get(
"watsonx_token", None
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
)
if optional_params.get("stream", False):
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
)
## RESPONSE OBJECT
response = response
if token is not None:
optional_params["token"] = token
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
litellm_params=litellm_params,
custom_llm_provider="watsonx_text",
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
logging_obj=logging,  # model call logging is done inside the class, as we may need to modify I/O to fit the provider's requirements
client=client,
)
elif custom_llm_provider == "vllm":
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
model_response = vllm_handler.completion(
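For reference, a hedged usage sketch of the refactored 'watsonx_text' branch above. The model name mirrors the updated test in this diff; credentials are assumed to come from WATSONX_APIKEY / WATSONX_URL (or the explicit parameters resolved above).

```python
# Sketch, not part of the diff: text generation via base_llm_http_handler.
import litellm

response = litellm.completion(
    model="watsonx_text/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello from litellm"}],
    max_tokens=20,
)
print(response.choices[0].message.content)

# Streaming flows through the WatsonxTextCompletionResponseIterator shown earlier.
for chunk in litellm.completion(
    model="watsonx_text/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello from litellm"}],
    max_tokens=20,
    stream=True,
):
    print(chunk)  # each chunk is a litellm.ModelResponseStream
```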
@ -3485,6 +3509,7 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params={},
)
elif custom_llm_provider == "gemini":
gemini_api_key = (
@ -3661,6 +3686,32 @@ def embedding( # noqa: PLR0915
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params={},
)
elif custom_llm_provider == "watsonx":
credentials = IBMWatsonXMixin.get_watsonx_credentials(
optional_params=optional_params, api_key=api_key, api_base=api_base
)
api_key = credentials["api_key"]
api_base = credentials["api_base"]
if "token" in credentials:
optional_params["token"] = credentials["token"]
response = base_llm_http_handler.embedding(
model=model,
input=input,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
litellm_params={},
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "xinference":
api_key = (
@ -3687,17 +3738,6 @@ def embedding( # noqa: PLR0915
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "watsonx":
response = watsonxai.embedding(
model=model,
input=input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
aembedding=aembedding,
api_key=api_key,
)
elif custom_llm_provider == "azure_ai":
api_base = (
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there

View file

@ -611,3 +611,6 @@ class FineTuningJobCreate(BaseModel):
class LiteLLMFineTuningJobCreate(FineTuningJobCreate):
custom_llm_provider: Literal["openai", "azure", "vertex_ai"]
AllEmbeddingInputValues = Union[str, List[str], List[int], List[List[int]]]

View file

@ -6,13 +6,15 @@ from pydantic import BaseModel
class WatsonXAPIParams(TypedDict):
url: str
api_key: Optional[str]
token: str
project_id: str
space_id: Optional[str]
region_name: Optional[str]
api_version: str
class WatsonXCredentials(TypedDict):
api_key: str
api_base: str
token: Optional[str]
class WatsonXAIEndpoint(str, Enum):
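A minimal sketch of the new WatsonXCredentials shape returned by IBMWatsonXMixin.get_watsonx_credentials and consumed in main.py; the values are placeholders.

```python
# Placeholder values only; illustrates the TypedDict added above.
from litellm.types.llms.watsonx import WatsonXCredentials

creds: WatsonXCredentials = {
    "api_key": "...",                                  # watsonx api key
    "api_base": "https://us-south.ml.cloud.ibm.com",   # example watsonx url
    "token": None,                                     # optional pre-fetched bearer token
}
```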

View file

@ -808,6 +808,8 @@ class ModelResponseStream(ModelResponseBase):
def __init__(
self,
choices: Optional[List[Union[StreamingChoices, dict, BaseModel]]] = None,
id: Optional[str] = None,
created: Optional[int] = None,
**kwargs,
):
if choices is not None and isinstance(choices, list):
@ -824,6 +826,20 @@ class ModelResponseStream(ModelResponseBase):
kwargs["choices"] = new_choices
else:
kwargs["choices"] = [StreamingChoices()]
if id is None:
id = _generate_id()
if created is None:
created = int(time.time())
kwargs["id"] = id
kwargs["created"] = created
kwargs["object"] = "chat.completion.chunk"
super().__init__(**kwargs)
def __contains__(self, key):
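A small sketch of the defaulting behavior added above; illustrative, not part of the diff.

```python
# Chunks constructed without id/created now generate sensible defaults.
from litellm.types.utils import ModelResponseStream

chunk = ModelResponseStream()                    # no id/created passed
assert chunk.object == "chat.completion.chunk"   # set unconditionally above
assert isinstance(chunk.id, str)                 # auto-generated via _generate_id()
assert isinstance(chunk.created, int)            # defaults to int(time.time())
print(len(chunk.choices))                        # 1, a default StreamingChoices()
```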

View file

@ -6244,7 +6244,9 @@ class ProviderConfigManager:
return litellm.VoyageEmbeddingConfig()
elif litellm.LlmProviders.TRITON == provider:
return litellm.TritonEmbeddingConfig()
raise ValueError(f"Provider {provider} does not support embedding config")
elif litellm.LlmProviders.WATSONX == provider:
return litellm.IBMWatsonXEmbeddingConfig()
raise ValueError(f"Provider {provider.value} does not support embedding config")
@staticmethod
def get_provider_rerank_config(

poetry.lock (generated): 2542 changed lines

File diff suppressed because it is too large

View file

@ -26,7 +26,6 @@ tokenizers = "*"
click = "*"
jinja2 = "^3.1.2"
aiohttp = "*"
requests = "^2.31.0"
pydantic = "^2.0.0"
jsonschema = "^4.22.0"

View file

@ -133,7 +133,7 @@ def test_completion_xai(stream):
for chunk in response:
print(chunk)
assert chunk is not None
assert isinstance(chunk, litellm.ModelResponse)
assert isinstance(chunk, litellm.ModelResponseStream)
assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)
else:

View file

@ -1,104 +0,0 @@
============================= test session starts ==============================
platform darwin -- Python 3.11.4, pytest-8.3.2, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/myenv/bin/python3.11
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm
configfile: pyproject.toml
plugins: asyncio-0.23.8, respx-0.21.1, anyio-4.6.0
asyncio: mode=Mode.STRICT
collecting ... collected 1 item
test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307] <module 'litellm' from '/Users/krrishdholakia/Documents/litellm/litellm/__init__.py'>
Request to litellm:
litellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}], tools=[{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], tool_choice='auto')
SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
Final returned optional params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
SENT optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096}
tool: {'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}

POST Request Sent from LiteLLM:
curl -X POST \
https://api.anthropic.com/v1/messages \
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}], 'tools': [{'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'input_schema': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'

_is_function_call: False
RAW RESPONSE:
{"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}
raw model_response: {"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}
Logging Details LiteLLM-Success Call: Cache_hit=None
Looking up model=claude-3-haiku-20240307 in model_cost_map
Looking up model=claude-3-haiku-20240307 in model_cost_map
Response
ModelResponse(id='chatcmpl-7222f6c2-962a-4776-8639-576723466cb7', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None))], created=1727897483, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=87, prompt_tokens=379, total_tokens=466, completion_tokens_details=None))
length of tool calls 1
Expecting there to be 3 tool calls
tool_calls: [ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')]
Response message
Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None)
messages: [{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]
Request to litellm:
litellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}], temperature=0.2, seed=22, drop_params=True)
SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
Final returned optional params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
SENT optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}], 'max_tokens': 4096}
tool: {'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}

POST Request Sent from LiteLLM:
curl -X POST \
https://api.anthropic.com/v1/messages \
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}, {'role': 'assistant', 'content': [{'type': 'tool_use', 'id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'name': 'get_current_weather', 'input': {'location': 'San Francisco', 'unit': 'celsius'}}]}, {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]}], 'temperature': 0.2, 'tools': [{'name': 'dummy-tool', 'description': '', 'input_schema': {'type': 'object', 'properties': {}}}], 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'

_is_function_call: False
RAW RESPONSE:
{"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}
raw model_response: {"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}
Logging Details LiteLLM-Success Call: Cache_hit=None
Looking up model=claude-3-haiku-20240307 in model_cost_map
Looking up model=claude-3-haiku-20240307 in model_cost_map
second response
ModelResponse(id='chatcmpl-c4ed5c25-ba7c-49e5-a6be-5720ab25fff0', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content='The current weather in San Francisco is 72°F (22°C).', role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "Tokyo", "unit": "celsius"}', name='get_current_weather'), id='toolu_01HTXEYDX4MspM76STtJqs1n', type='function')], function_call=None))], created=1727897484, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=90, prompt_tokens=426, total_tokens=516, completion_tokens_details=None))
PASSED
=============================== warnings summary ===============================
../../myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../../litellm/utils.py:17
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:17: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
import imghdr
../../litellm/utils.py:124
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:124: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f:
test_function_calling.py:56
/Users/krrishdholakia/Documents/litellm/tests/local_testing/test_function_calling.py:56: PytestUnknownMarkWarning: Unknown pytest.mark.flaky - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@pytest.mark.flaky(retries=3, delay=1)
tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/httpx/_content.py:202: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
warnings.warn(message, DeprecationWarning)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 1 passed, 6 warnings in 1.89s =========================

View file

@ -1805,7 +1805,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
for chunk in response:
chunks.append(chunk)
assert isinstance(chunk, litellm.ModelResponse)
assert isinstance(chunk, litellm.ModelResponseStream)
else:
response = await litellm.acompletion(**data)
print(f"completion: {response}")
@ -1815,7 +1815,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
async for chunk in response:
print(f"chunk: {chunk}")
chunks.append(chunk)
assert isinstance(chunk, litellm.ModelResponse)
assert isinstance(chunk, litellm.ModelResponseStream)
complete_response = litellm.stream_chunk_builder(chunks=chunks)
assert (

View file

@ -4019,19 +4019,20 @@ def test_completion_deepseek():
@pytest.mark.skip(reason="Account deleted by IBM.")
def test_completion_watsonx_error():
litellm.set_verbose = True
model_name = "watsonx/ibm/granite-13b-chat-v2"
model_name = "watsonx_text/ibm/granite-13b-chat-v2"
with pytest.raises(litellm.BadRequestError) as e:
response = completion(
model=model_name,
messages=messages,
stop=["stop"],
max_tokens=20,
)
# Add any assertions here to check the response
print(response)
response = completion(
model=model_name,
messages=messages,
stop=["stop"],
max_tokens=20,
stream=True,
)
assert "use 'watsonx_text' route instead" in str(e).lower()
for chunk in response:
print(chunk)
# Add any assertions here to check the response
print(response)
@pytest.mark.skip(reason="Skip test. account deleted.")

View file

@ -135,7 +135,7 @@ class CompletionCustomHandler(
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
assert isinstance(response_obj, litellm.ModelResponseStream)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(

View file

@ -153,7 +153,7 @@ class CompletionCustomHandler(
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
assert isinstance(response_obj, litellm.ModelResponseStream)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(

View file

@ -122,8 +122,13 @@ async def test_async_chat_openai_stream():
complete_streaming_response = complete_streaming_response.strip("'")
print(f"complete_streaming_response: {complete_streaming_response}")
await asyncio.sleep(3)
print(
f"tmp_function.complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback}"
)
# problematic line
response1 = tmp_function.complete_streaming_response_in_callback["choices"][0][
"message"

View file

@ -801,8 +801,11 @@ def test_fireworks_embeddings():
def test_watsonx_embeddings():
from litellm.llms.custom_httpx.http_handler import HTTPHandler
def mock_wx_embed_request(method: str, url: str, **kwargs):
client = HTTPHandler()
def mock_wx_embed_request(url: str, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
@ -816,12 +819,14 @@ def test_watsonx_embeddings():
try:
litellm.set_verbose = True
with patch("requests.request", side_effect=mock_wx_embed_request):
with patch.object(client, "post", side_effect=mock_wx_embed_request):
response = litellm.embedding(
model="watsonx/ibm/slate-30m-english-rtrvr",
input=["good morning from litellm"],
token="secret-token",
client=client,
)
print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage)
except litellm.RateLimitError as e:
@ -832,6 +837,9 @@ def test_watsonx_embeddings():
@pytest.mark.asyncio
async def test_watsonx_aembeddings():
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
client = AsyncHTTPHandler()
def mock_async_client(*args, **kwargs):
@ -856,12 +864,14 @@ async def test_watsonx_aembeddings():
try:
litellm.set_verbose = True
with patch("httpx.AsyncClient", side_effect=mock_async_client):
with patch.object(client, "post", side_effect=mock_async_client) as mock_client:
response = await litellm.aembedding(
model="watsonx/ibm/slate-30m-english-rtrvr",
input=["good morning from litellm"],
token="secret-token",
client=client,
)
mock_client.assert_called_once()
print(f"response: {response}")
assert isinstance(response.usage, litellm.Usage)
except litellm.RateLimitError as e:

View file

@ -17,6 +17,7 @@ from pydantic import BaseModel
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.utils import ModelResponseListIterator
from litellm.types.utils import ModelResponseStream
sys.path.insert(
0, os.path.abspath("../..")
@ -69,7 +70,7 @@ first_openai_chunk_example = {
def validate_first_format(chunk):
# write a test to make sure chunk follows the same format as first_openai_chunk_example
assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
assert isinstance(chunk, ModelResponseStream), "Chunk should be a ModelResponseStream."
assert isinstance(chunk["id"], str), "'id' should be a string."
assert isinstance(chunk["object"], str), "'object' should be a string."
assert isinstance(chunk["created"], int), "'created' should be an integer."
@ -99,7 +100,7 @@ second_openai_chunk_example = {
def validate_second_format(chunk):
assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
assert isinstance(chunk, ModelResponseStream), "Chunk should be a ModelResponseStream."
assert isinstance(chunk["id"], str), "'id' should be a string."
assert isinstance(chunk["object"], str), "'object' should be a string."
assert isinstance(chunk["created"], int), "'created' should be an integer."
@ -137,7 +138,7 @@ def validate_last_format(chunk):
"""
Ensure last chunk has no remaining content or tools
"""
assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
assert isinstance(chunk, ModelResponseStream), "Chunk should be a ModelResponseStream."
assert isinstance(chunk["id"], str), "'id' should be a string."
assert isinstance(chunk["object"], str), "'object' should be a string."
assert isinstance(chunk["created"], int), "'created' should be an integer."
@ -1523,7 +1524,7 @@ async def test_parallel_streaming_requests(sync_mode, model):
num_finish_reason = 0
for chunk in response:
print(f"chunk: {chunk}")
if isinstance(chunk, ModelResponse):
if isinstance(chunk, ModelResponseStream):
if chunk.choices[0].finish_reason is not None:
num_finish_reason += 1
assert num_finish_reason == 1
@ -1541,7 +1542,7 @@ async def test_parallel_streaming_requests(sync_mode, model):
num_finish_reason = 0
async for chunk in response:
print(f"type of chunk: {type(chunk)}")
if isinstance(chunk, ModelResponse):
if isinstance(chunk, ModelResponseStream):
print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
if chunk.choices[0].finish_reason is not None:
num_finish_reason += 1

View file

@ -1,104 +0,0 @@
============================= test session starts ==============================
platform darwin -- Python 3.11.4, pytest-8.3.2, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/myenv/bin/python3.11
cachedir: .pytest_cache
rootdir: /Users/krrishdholakia/Documents/litellm
configfile: pyproject.toml
plugins: asyncio-0.23.8, respx-0.21.1, anyio-4.6.0
asyncio: mode=Mode.STRICT
collecting ... collected 1 item
test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307] <module 'litellm' from '/Users/krrishdholakia/Documents/litellm/litellm/__init__.py'>
Request to litellm:
litellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}], tools=[{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], tool_choice='auto')
SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
Final returned optional params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
SENT optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096}
tool: {'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}

POST Request Sent from LiteLLM:
curl -X POST \
https://api.anthropic.com/v1/messages \
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}], 'tools': [{'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'input_schema': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'

_is_function_call: False
RAW RESPONSE:
{"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}
raw model_response: {"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}
Logging Details LiteLLM-Success Call: Cache_hit=None
Looking up model=claude-3-haiku-20240307 in model_cost_map
Looking up model=claude-3-haiku-20240307 in model_cost_map
Response
ModelResponse(id='chatcmpl-7222f6c2-962a-4776-8639-576723466cb7', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None))], created=1727897483, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=87, prompt_tokens=379, total_tokens=466, completion_tokens_details=None))
length of tool calls 1
Expecting there to be 3 tool calls
tool_calls: [ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')]
Response message
Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None)
messages: [{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]
Request to litellm:
litellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}], temperature=0.2, seed=22, drop_params=True)
SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
Final returned optional params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
SENT optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}], 'max_tokens': 4096}
tool: {'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}

POST Request Sent from LiteLLM:
curl -X POST \
https://api.anthropic.com/v1/messages \
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}, {'role': 'assistant', 'content': [{'type': 'tool_use', 'id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'name': 'get_current_weather', 'input': {'location': 'San Francisco', 'unit': 'celsius'}}]}, {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]}], 'temperature': 0.2, 'tools': [{'name': 'dummy-tool', 'description': '', 'input_schema': {'type': 'object', 'properties': {}}}], 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'

_is_function_call: False
RAW RESPONSE:
{"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}
raw model_response: {"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}
Logging Details LiteLLM-Success Call: Cache_hit=None
Looking up model=claude-3-haiku-20240307 in model_cost_map
Looking up model=claude-3-haiku-20240307 in model_cost_map
second response
ModelResponse(id='chatcmpl-c4ed5c25-ba7c-49e5-a6be-5720ab25fff0', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content='The current weather in San Francisco is 72°F (22°C).', role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "Tokyo", "unit": "celsius"}', name='get_current_weather'), id='toolu_01HTXEYDX4MspM76STtJqs1n', type='function')], function_call=None))], created=1727897484, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=90, prompt_tokens=426, total_tokens=516, completion_tokens_details=None))
PASSED
=============================== warnings summary ===============================
../../myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
../../litellm/utils.py:17
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:17: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
import imghdr
../../litellm/utils.py:124
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:124: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f:
test_function_calling.py:56
/Users/krrishdholakia/Documents/litellm/tests/local_testing/test_function_calling.py:56: PytestUnknownMarkWarning: Unknown pytest.mark.flaky - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
@pytest.mark.flaky(retries=3, delay=1)
tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/httpx/_content.py:202: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
warnings.warn(message, DeprecationWarning)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 1 passed, 6 warnings in 1.89s =========================