Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
Complete 'requests' library removal (#7350)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 12s
* refactor: initial commit moving watsonx_text to base_llm_http_handler + clarifying new provider directory structure
* refactor(watsonx/completion/handler.py): move to using base llm http handler - removes 'requests' library usage
* fix(watsonx_text/transformation.py): fix result transformation - migrates to transformation.py, for usage with base llm http handler
* fix(streaming_handler.py): migrate watsonx streaming to transformation.py - ensures streaming works with base llm http handler
* fix(streaming_handler.py): fix streaming linting errors and remove watsonx conditional logic
* fix(watsonx/): fix chat route post completion route refactor
* refactor(watsonx/embed): refactor watsonx to use base llm http handler for embedding calls as well
* refactor(base.py): remove requests library usage from litellm
* build(pyproject.toml): remove requests library usage
* fix: fix linting errors
* fix: fix linting errors
* fix(types/utils.py): fix validation errors for modelresponsestream
* fix(replicate/handler.py): fix linting errors
* fix(litellm_logging.py): handle modelresponsestream object
* fix(streaming_handler.py): fix modelresponsestream args
* fix: remove unused imports
* test: fix test
* fix: fix test
* test: fix test
* test: fix tests
* test: fix test
* test: fix patch target
* test: fix test
This commit is contained in:
parent 8b1ea40e7b
commit 3671829e39
39 changed files with 2147 additions and 2279 deletions
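The theme of this change is replacing direct `requests` calls with litellm's shared httpx-based client. A minimal before/after sketch of that pattern (assuming `HTTPHandler` from `litellm.llms.custom_httpx.http_handler` wraps `httpx.Client` and exposes a matching `post()`, as the handlers in this diff use; the URL and payload below are illustrative, not a real provider request):

```python
# Before (pattern removed by this PR):
#   import requests
#   resp = requests.post(url, headers=headers, json=payload)  # returns requests.Response

# After: go through litellm's shared httpx-based handler.
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
resp = client.post(
    url="https://example.com/v1/chat/completions",  # illustrative endpoint
    headers={"Content-Type": "application/json"},
    json={"model": "example-model", "messages": [{"role": "user", "content": "hi"}]},
)
resp.raise_for_status()  # resp is an httpx.Response now, not a requests.Response
print(resp.json())
```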
.gitignore (vendored): 1 change
@@ -67,3 +67,4 @@ litellm/tests/langfuse.log
litellm/proxy/google-cloud-sdk/*
tests/llm_translation/log.txt
venv/
tests/local_testing/log.txt
@@ -5,16 +5,16 @@ When adding a new provider, you need to create a directory for the provider that
```
litellm/llms/
└── provider_name/
    ├── completion/
    ├── completion/ # use when endpoint is equivalent to openai's `/v1/completions`
    │   ├── handler.py
    │   └── transformation.py
    ├── chat/
    ├── chat/ # use when endpoint is equivalent to openai's `/v1/chat/completions`
    │   ├── handler.py
    │   └── transformation.py
    ├── embed/
    ├── embed/ # use when endpoint is equivalent to openai's `/v1/embeddings`
    │   ├── handler.py
    │   └── transformation.py
    └── rerank/
    └── rerank/ # use when endpoint is equivalent to cohere's `/rerank` endpoint.
        ├── handler.py
        └── transformation.py
```
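For orientation, a rough sketch of what a new provider's `chat/transformation.py` might start from, using only `BaseConfig` hooks that appear elsewhere in this diff (`get_supported_openai_params`, `validate_environment`, `get_complete_url`). The provider name, endpoint, and env var are hypothetical, and the real base class has further abstract methods (request/response transforms) that a working config must also implement:

```python
from typing import Dict, List, Optional

from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues


class ExampleProviderChatConfig(BaseConfig):
    """Hypothetical provider config - shows the hook shapes, not a complete implementation."""

    def get_supported_openai_params(self, model: str) -> List[str]:
        return ["max_tokens", "temperature", "stream"]

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
    ) -> Dict:
        # EXAMPLE_PROVIDER_API_KEY is a made-up env var for illustration.
        api_key = api_key or get_secret_str("EXAMPLE_PROVIDER_API_KEY")
        return {
            **headers,
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        # Same widened signature this PR introduces on BaseConfig.
        return f"{api_base.rstrip('/')}/v1/chat/completions"
```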
@@ -991,6 +991,7 @@ from .utils import (
    get_api_base,
    get_first_chars_messages,
    ModelResponse,
    ModelResponseStream,
    EmbeddingResponse,
    ImageResponse,
    TranscriptionResponse,
@@ -1157,6 +1158,7 @@ from .llms.perplexity.chat.transformation import PerplexityChatConfig
from .llms.azure.chat.o1_transformation import AzureOpenAIO1Config
from .llms.watsonx.completion.transformation import IBMWatsonXAIConfig
from .llms.watsonx.chat.transformation import IBMWatsonXChatConfig
from .llms.watsonx.embed.transformation import IBMWatsonXEmbeddingConfig
from .main import *  # type: ignore
from .integrations import *
from .exceptions import (
@@ -43,6 +43,7 @@ from litellm.types.utils import (
    ImageResponse,
    LiteLLMLoggingBaseClass,
    ModelResponse,
    ModelResponseStream,
    StandardCallbackDynamicParams,
    StandardLoggingAdditionalHeaders,
    StandardLoggingHiddenParams,
@@ -741,6 +742,7 @@ class Logging(LiteLLMLoggingBaseClass):
        self,
        result: Union[
            ModelResponse,
            ModelResponseStream,
            EmbeddingResponse,
            ImageResponse,
            TranscriptionResponse,
@@ -848,6 +850,7 @@ class Logging(LiteLLMLoggingBaseClass):
        ): # handle streaming separately
            if (
                isinstance(result, ModelResponse)
                or isinstance(result, ModelResponseStream)
                or isinstance(result, EmbeddingResponse)
                or isinstance(result, ImageResponse)
                or isinstance(result, TranscriptionResponse)
@@ -955,6 +958,7 @@ class Logging(LiteLLMLoggingBaseClass):
            if self.stream and (
                isinstance(result, litellm.ModelResponse)
                or isinstance(result, TextCompletionResponse)
                or isinstance(result, ModelResponseStream)
            ):
                complete_streaming_response: Optional[
                    Union[ModelResponse, TextCompletionResponse]
@@ -966,9 +970,6 @@ class Logging(LiteLLMLoggingBaseClass):
                    streaming_chunks=self.sync_streaming_chunks,
                    is_async=False,
                )
                _caching_complete_streaming_response: Optional[
                    Union[ModelResponse, TextCompletionResponse]
                ] = None
                if complete_streaming_response is not None:
                    verbose_logger.debug(
                        "Logging Details LiteLLM-Success Call streaming complete"
@@ -976,9 +977,6 @@ class Logging(LiteLLMLoggingBaseClass):
                    self.model_call_details["complete_streaming_response"] = (
                        complete_streaming_response
                    )
                    _caching_complete_streaming_response = copy.deepcopy(
                        complete_streaming_response
                    )
                    self.model_call_details["response_cost"] = (
                        self._response_cost_calculator(result=complete_streaming_response)
                    )
@@ -1474,6 +1472,7 @@ class Logging(LiteLLMLoggingBaseClass):
        ] = None
        if self.stream is True and (
            isinstance(result, litellm.ModelResponse)
            or isinstance(result, litellm.ModelResponseStream)
            or isinstance(result, TextCompletionResponse)
        ):
            complete_streaming_response: Optional[
@@ -2,7 +2,11 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union

from litellm._logging import verbose_logger
from litellm.types.utils import ModelResponse, TextCompletionResponse
from litellm.types.utils import (
    ModelResponse,
    ModelResponseStream,
    TextCompletionResponse,
)

if TYPE_CHECKING:
    from litellm import ModelResponse as _ModelResponse

@@ -38,7 +42,7 @@ def convert_litellm_response_object_to_str(


def _assemble_complete_response_from_streaming_chunks(
    result: Union[ModelResponse, TextCompletionResponse],
    result: Union[ModelResponse, TextCompletionResponse, ModelResponseStream],
    start_time: datetime,
    end_time: datetime,
    request_kwargs: dict,
@@ -5,7 +5,7 @@ import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, cast

import httpx
from pydantic import BaseModel
@@ -611,44 +611,6 @@ class CustomStreamWrapper:
        except Exception as e:
            raise e

    def handle_watsonx_stream(self, chunk):
        try:
            if isinstance(chunk, dict):
                parsed_response = chunk
            elif isinstance(chunk, (str, bytes)):
                if isinstance(chunk, bytes):
                    chunk = chunk.decode("utf-8")
                if "generated_text" in chunk:
                    response = chunk.replace("data: ", "").strip()
                    parsed_response = json.loads(response)
                else:
                    return {
                        "text": "",
                        "is_finished": False,
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                    }
            else:
                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
                raise ValueError(
                    f"Unable to parse response. Original response: {chunk}"
                )
            results = parsed_response.get("results", [])
            if len(results) > 0:
                text = results[0].get("generated_text", "")
                finish_reason = results[0].get("stop_reason")
                is_finished = finish_reason != "not_finished"
                return {
                    "text": text,
                    "is_finished": is_finished,
                    "finish_reason": finish_reason,
                    "prompt_tokens": results[0].get("input_token_count", 0),
                    "completion_tokens": results[0].get("generated_token_count", 0),
                }
            return {"text": "", "is_finished": False}
        except Exception as e:
            raise e

    def handle_triton_stream(self, chunk):
        try:
            if isinstance(chunk, dict):
@@ -702,9 +664,18 @@ class CustomStreamWrapper:
        # pop model keyword
        chunk.pop("model", None)

        model_response = ModelResponse(
            stream=True, model=_model, stream_options=self.stream_options, **chunk
        )
        chunk_dict = {}
        for key, value in chunk.items():
            if key != "stream":
                chunk_dict[key] = value

        args = {
            "model": _model,
            "stream_options": self.stream_options,
            **chunk_dict,
        }

        model_response = ModelResponseStream(**args)
        if self.response_id is not None:
            model_response.id = self.response_id
        else:
@@ -742,9 +713,9 @@ class CustomStreamWrapper:

    def return_processed_chunk_logic( # noqa
        self,
        completion_obj: dict,
        completion_obj: Dict[str, Any],
        model_response: ModelResponseStream,
        response_obj: dict,
        response_obj: Dict[str, Any],
    ):

        print_verbose(
@@ -887,11 +858,11 @@ class CustomStreamWrapper:

    def chunk_creator(self, chunk): # type: ignore # noqa: PLR0915
        model_response = self.model_response_creator()
        response_obj: dict = {}
        response_obj: Dict[str, Any] = {}

        try:
            # return this for all models
            completion_obj = {"content": ""}
            completion_obj: Dict[str, Any] = {"content": ""}
            from litellm.types.utils import GenericStreamingChunk as GChunk

            if (
@@ -1089,11 +1060,6 @@ class CustomStreamWrapper:
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "watsonx":
                response_obj = self.handle_watsonx_stream(chunk)
                completion_obj["content"] = response_obj["text"]
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "triton":
                response_obj = self.handle_triton_stream(chunk)
                completion_obj["content"] = response_obj["text"]
@@ -1158,7 +1124,7 @@ class CustomStreamWrapper:
                    self.received_finish_reason = response_obj["finish_reason"]
            else: # openai / azure chat model
                if self.custom_llm_provider == "azure":
                    if hasattr(chunk, "model"):
                    if isinstance(chunk, BaseModel) and hasattr(chunk, "model"):
                        # for azure, we need to pass the model from the orignal chunk
                        self.model = chunk.model
                response_obj = self.handle_openai_chat_completion_chunk(chunk)
@@ -1190,7 +1156,10 @@ class CustomStreamWrapper:

            if response_obj["usage"] is not None:
                if isinstance(response_obj["usage"], dict):
                    model_response.usage = litellm.Usage(
                    setattr(
                        model_response,
                        "usage",
                        litellm.Usage(
                            prompt_tokens=response_obj["usage"].get(
                                "prompt_tokens", None
                            )
@@ -1199,12 +1168,17 @@ class CustomStreamWrapper:
                                "completion_tokens", None
                            )
                            or None,
                            total_tokens=response_obj["usage"].get("total_tokens", None)
                            total_tokens=response_obj["usage"].get(
                                "total_tokens", None
                            )
                            or None,
                        ),
                    )
                elif isinstance(response_obj["usage"], BaseModel):
                    model_response.usage = litellm.Usage(
                        **response_obj["usage"].model_dump()
                    setattr(
                        model_response,
                        "usage",
                        litellm.Usage(**response_obj["usage"].model_dump()),
                    )

            model_response.model = self.model
@@ -1337,7 +1311,7 @@ class CustomStreamWrapper:
                raise StopIteration
        except Exception as e:
            traceback.format_exc()
            e.message = str(e)
            setattr(e, "message", str(e))
            raise exception_type(
                model=self.model,
                custom_llm_provider=self.custom_llm_provider,
@@ -1434,7 +1408,9 @@ class CustomStreamWrapper:
            print_verbose(
                f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}; custom_llm_provider: {self.custom_llm_provider}"
            )
            response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
            response: Optional[ModelResponseStream] = self.chunk_creator(
                chunk=chunk
            )
            print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")

            if response is None:
@@ -1597,7 +1573,7 @@ class CustomStreamWrapper:
            # __anext__ also calls async_success_handler, which does logging
            print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")

            processed_chunk: Optional[ModelResponse] = self.chunk_creator(
            processed_chunk: Optional[ModelResponseStream] = self.chunk_creator(
                chunk=chunk
            )
            print_verbose(
@@ -1624,7 +1600,7 @@ class CustomStreamWrapper:
                if self.logging_obj._llm_caching_handler is not None:
                    asyncio.create_task(
                        self.logging_obj._llm_caching_handler._add_streaming_response_to_cache(
                            processed_chunk=processed_chunk,
                            processed_chunk=cast(ModelResponse, processed_chunk),
                        )
                    )

@@ -1663,8 +1639,8 @@ class CustomStreamWrapper:
                chunk = next(self.completion_stream)
                if chunk is not None and chunk != b"":
                    print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
                    processed_chunk: Optional[ModelResponse] = self.chunk_creator(
                        chunk=chunk
                    processed_chunk: Optional[ModelResponseStream] = (
                        self.chunk_creator(chunk=chunk)
                    )
                    print_verbose(
                        f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
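To make the streaming change concrete: chunks are now materialized as `ModelResponseStream` instead of `ModelResponse(stream=True, ...)`. A small sketch of the construction pattern used by `chunk_creator` above (the chunk dict is illustrative; real chunks carry whatever fields the provider transformation produced):

```python
from litellm.types.utils import ModelResponseStream

# Illustrative chunk dict - mirrors how chunk_creator pops "model" and drops "stream"
# before forwarding the remaining keys to ModelResponseStream.
chunk = {"model": "example-model", "stream": True}

_model = chunk.pop("model", None)
chunk_dict = {k: v for k, v in chunk.items() if k != "stream"}

model_response = ModelResponseStream(model=_model, stream_options=None, **chunk_dict)
model_response.id = "chatcmpl-example"  # chunk_creator reuses self.response_id the same way
print(type(model_response).__name__, model_response.model)
```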
@@ -2,7 +2,6 @@
from typing import Any, Optional, Union

import httpx
import requests

import litellm
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
@@ -16,7 +15,7 @@ class BaseLLM:
    def process_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        response: httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: Any,
@@ -35,7 +34,7 @@ class BaseLLM:
    def process_text_completion_response(
        self,
        model: str,
        response: Union[requests.Response, httpx.Response],
        response: httpx.Response,
        model_response: TextCompletionResponse,
        stream: bool,
        logging_obj: Any,
@@ -107,7 +107,13 @@ class BaseConfig(ABC):
    ) -> dict:
        pass

    def get_complete_url(self, api_base: str, model: str) -> str:
    def get_complete_url(
        self,
        api_base: str,
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL
@ -1,10 +1,10 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -16,12 +16,11 @@ else:
|
|||
|
||||
|
||||
class BaseEmbeddingConfig(BaseConfig, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, List[str], List[float], List[List[float]]],
|
||||
input: AllEmbeddingInputValues,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
|
@ -34,14 +33,20 @@ class BaseEmbeddingConfig(BaseConfig, ABC):
|
|||
raw_response: httpx.Response,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
request_data: dict = {},
|
||||
optional_params: dict = {},
|
||||
litellm_params: dict = {},
|
||||
api_key: Optional[str],
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> EmbeddingResponse:
|
||||
return model_response
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
|
|
|
@ -72,7 +72,13 @@ class CloudflareChatConfig(BaseConfig):
|
|||
}
|
||||
return headers
|
||||
|
||||
def get_complete_url(self, api_base: str, model: str) -> str:
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: str,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
return api_base + model
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
|
|
|
@ -110,6 +110,8 @@ class BaseLLMHTTPHandler:
|
|||
api_base = provider_config.get_complete_url(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
data = provider_config.transform_request(
|
||||
|
@ -402,6 +404,7 @@ class BaseLLMHTTPHandler:
|
|||
logging_obj: LiteLLMLoggingObj,
|
||||
api_base: Optional[str],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
model_response: EmbeddingResponse,
|
||||
api_key: Optional[str] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
|
@ -424,6 +427,7 @@ class BaseLLMHTTPHandler:
|
|||
api_base = provider_config.get_complete_url(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
data = provider_config.transform_embedding_request(
|
||||
|
@ -457,6 +461,8 @@ class BaseLLMHTTPHandler:
|
|||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
|
@ -484,6 +490,8 @@ class BaseLLMHTTPHandler:
|
|||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
request_data=data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
async def aembedding(
|
||||
|
@ -496,6 +504,8 @@ class BaseLLMHTTPHandler:
|
|||
provider_config: BaseEmbeddingConfig,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
|
@ -524,6 +534,8 @@ class BaseLLMHTTPHandler:
|
|||
logging_obj=logging_obj,
|
||||
api_key=api_key,
|
||||
request_data=request_data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
)
|
||||
|
||||
def rerank(
|
||||
|
|
|
@ -350,7 +350,13 @@ class OllamaConfig(BaseConfig):
|
|||
) -> dict:
|
||||
return headers
|
||||
|
||||
def get_complete_url(self, api_base: str, model: str) -> str:
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: str,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
|
|
|
@ -168,7 +168,9 @@ def completion(
|
|||
time.time()
|
||||
) # for pricing this must remain right before calling api
|
||||
|
||||
prediction_url = replicate_config.get_complete_url(api_base, model)
|
||||
prediction_url = replicate_config.get_complete_url(
|
||||
api_base=api_base, model=model, optional_params=optional_params
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
httpx_client = _get_httpx_client(
|
||||
|
@ -235,7 +237,9 @@ async def async_completion(
|
|||
headers: dict,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
|
||||
prediction_url = replicate_config.get_complete_url(api_base=api_base, model=model)
|
||||
prediction_url = replicate_config.get_complete_url(
|
||||
api_base=api_base, model=model, optional_params=optional_params
|
||||
)
|
||||
async_handler = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.REPLICATE,
|
||||
params={"timeout": 600.0},
|
||||
|
|
|
@ -136,7 +136,13 @@ class ReplicateConfig(BaseConfig):
|
|||
status_code=status_code, message=error_message, headers=headers
|
||||
)
|
||||
|
||||
def get_complete_url(self, api_base: str, model: str) -> str:
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: str,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
version_id = self.model_to_version_id(model)
|
||||
base_url = api_base
|
||||
if "deployments" in version_id:
|
||||
|
|
|
@ -7,6 +7,7 @@ from litellm.llms.base_llm.embedding.transformation import (
|
|||
BaseEmbeddingConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from ..common_utils import TritonError
|
||||
|
@ -48,7 +49,7 @@ class TritonEmbeddingConfig(BaseEmbeddingConfig):
|
|||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, List[str], List[float], List[List[float]]],
|
||||
input: AllEmbeddingInputValues,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
|
|
|
@ -6,7 +6,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
|||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse, Usage
|
||||
|
||||
|
||||
|
@ -38,7 +38,13 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
|||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_complete_url(self, api_base: Optional[str], model: str) -> str:
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
if api_base:
|
||||
if not api_base.endswith("/embeddings"):
|
||||
api_base = f"{api_base}/embeddings"
|
||||
|
@ -90,7 +96,7 @@ class VoyageEmbeddingConfig(BaseEmbeddingConfig):
|
|||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, List[str], List[float], List[List[float]]],
|
||||
input: AllEmbeddingInputValues,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
|
|
|
@ -3,57 +3,19 @@ from typing import Callable, Optional, Union
|
|||
import httpx
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
|
||||
from litellm.types.utils import CustomStreamingDecoder, ModelResponse
|
||||
|
||||
from ...openai_like.chat.handler import OpenAILikeChatHandler
|
||||
from ..common_utils import WatsonXAIError, _get_api_params
|
||||
from ..common_utils import _get_api_params
|
||||
from .transformation import IBMWatsonXChatConfig
|
||||
|
||||
watsonx_chat_transformation = IBMWatsonXChatConfig()
|
||||
|
||||
|
||||
class WatsonXChatHandler(OpenAILikeChatHandler):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _prepare_url(
|
||||
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
|
||||
) -> str:
|
||||
if model.startswith("deployment/"):
|
||||
if api_params.get("space_id") is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||
)
|
||||
deployment_id = "/".join(model.split("/")[1:])
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
|
||||
if stream is True
|
||||
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
|
||||
)
|
||||
endpoint = endpoint.format(deployment_id=deployment_id)
|
||||
else:
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.CHAT_STREAM.value
|
||||
if stream is True
|
||||
else WatsonXAIEndpoint.CHAT.value
|
||||
)
|
||||
base_url = httpx.URL(api_params["url"])
|
||||
base_url = base_url.join(endpoint)
|
||||
full_url = str(
|
||||
base_url.copy_add_param(key="version", value=api_params["api_version"])
|
||||
)
|
||||
|
||||
return full_url
|
||||
|
||||
def _prepare_payload(
|
||||
self, model: str, api_params: WatsonXAPIParams, stream: Optional[bool]
|
||||
) -> dict:
|
||||
payload: dict = {}
|
||||
if model.startswith("deployment/"):
|
||||
return payload
|
||||
payload["model_id"] = model
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
return payload
|
||||
|
||||
def completion(
|
||||
self,
|
||||
*,
|
||||
|
@ -70,32 +32,37 @@ class WatsonXChatHandler(OpenAILikeChatHandler):
|
|||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
logger_fn=None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
fake_stream: bool = False,
|
||||
):
|
||||
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
|
||||
api_params = _get_api_params(params=optional_params)
|
||||
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_params['token']}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
## UPDATE HEADERS
|
||||
headers = watsonx_chat_transformation.validate_environment(
|
||||
headers=headers or {},
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
stream: Optional[bool] = optional_params.get("stream", False)
|
||||
## GET API URL
|
||||
api_base = watsonx_chat_transformation.get_complete_url(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
stream=optional_params.get("stream", False),
|
||||
)
|
||||
|
||||
## get api url and payload
|
||||
api_base = self._prepare_url(model=model, api_params=api_params, stream=stream)
|
||||
watsonx_auth_payload = self._prepare_payload(
|
||||
model=model, api_params=api_params, stream=stream
|
||||
## UPDATE PAYLOAD (optional params)
|
||||
watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
|
||||
model=model,
|
||||
api_params=api_params,
|
||||
)
|
||||
optional_params.update(watsonx_auth_payload)
|
||||
|
||||
|
|
|
@ -7,12 +7,14 @@ Docs: https://cloud.ibm.com/apidocs/watsonx-ai#text-chat
|
|||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams
|
||||
|
||||
from ....utils import _remove_additional_properties, _remove_strict_from_schema
|
||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from ..common_utils import IBMWatsonXMixin, WatsonXAIError
|
||||
|
||||
|
||||
class IBMWatsonXChatConfig(OpenAIGPTConfig):
|
||||
class IBMWatsonXChatConfig(IBMWatsonXMixin, OpenAIGPTConfig):
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
return [
|
||||
|
@ -75,3 +77,47 @@ class IBMWatsonXChatConfig(OpenAIGPTConfig):
|
|||
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
|
||||
) # vllm does not require an api key
|
||||
return api_base, dynamic_api_key
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: str,
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
url = self._get_base_url(api_base=api_base)
|
||||
if model.startswith("deployment/"):
|
||||
# deployment models are passed in as 'deployment/<deployment_id>'
|
||||
if optional_params.get("space_id") is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||
)
|
||||
deployment_id = "/".join(model.split("/")[1:])
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
|
||||
if stream
|
||||
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
|
||||
)
|
||||
endpoint = endpoint.format(deployment_id=deployment_id)
|
||||
else:
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.DEPLOYMENT_CHAT_STREAM.value
|
||||
if stream
|
||||
else WatsonXAIEndpoint.DEPLOYMENT_CHAT.value
|
||||
)
|
||||
url = url.rstrip("/") + endpoint
|
||||
|
||||
## add api version
|
||||
url = self._add_api_version_to_url(
|
||||
url=url, api_version=optional_params.pop("api_version", None)
|
||||
)
|
||||
return url
|
||||
|
||||
def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
|
||||
payload: dict = {}
|
||||
if model.startswith("deployment/"):
|
||||
return payload
|
||||
payload["model_id"] = model
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
return payload
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
from typing import Callable, Dict, Optional, Union, cast
|
||||
from typing import Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.caching import InMemoryCache
|
||||
from litellm.litellm_core_utils.prompt_templates import factory as ptf
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.watsonx import WatsonXAPIParams
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.watsonx import WatsonXAPIParams, WatsonXCredentials
|
||||
|
||||
|
||||
class WatsonXAIError(BaseLLMException):
|
||||
|
@ -65,18 +67,20 @@ def generate_iam_token(api_key=None, **params) -> str:
|
|||
return cast(str, result)
|
||||
|
||||
|
||||
def _generate_watsonx_token(api_key: Optional[str], token: Optional[str]) -> str:
|
||||
if token is not None:
|
||||
return token
|
||||
token = generate_iam_token(api_key)
|
||||
return token
|
||||
|
||||
|
||||
def _get_api_params(
|
||||
params: dict,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
generate_token: Optional[bool] = True,
|
||||
) -> WatsonXAPIParams:
|
||||
"""
|
||||
Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
|
||||
"""
|
||||
# Load auth variables from params
|
||||
url = params.pop("url", params.pop("api_base", params.pop("base_url", None)))
|
||||
api_key = params.pop("apikey", None)
|
||||
token = params.pop("token", None)
|
||||
project_id = params.pop(
|
||||
"project_id", params.pop("watsonx_project", None)
|
||||
) # watsonx.ai project_id - allow 'watsonx_project' to be consistent with how vertex project implementation works -> reduce provider-specific params
|
||||
|
@ -86,29 +90,8 @@ def _get_api_params(
|
|||
region_name = params.pop(
|
||||
"watsonx_region_name", params.pop("watsonx_region", None)
|
||||
) # consistent with how vertex ai + aws regions are accepted
|
||||
wx_credentials = params.pop(
|
||||
"wx_credentials",
|
||||
params.pop(
|
||||
"watsonx_credentials", None
|
||||
), # follow {provider}_credentials, same as vertex ai
|
||||
)
|
||||
api_version = params.pop("api_version", litellm.WATSONX_DEFAULT_API_VERSION)
|
||||
|
||||
# Load auth variables from environment variables
|
||||
if url is None:
|
||||
url = (
|
||||
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
|
||||
or get_secret_str("WATSONX_URL")
|
||||
or get_secret_str("WX_URL")
|
||||
or get_secret_str("WML_URL")
|
||||
)
|
||||
if api_key is None:
|
||||
api_key = (
|
||||
get_secret_str("WATSONX_APIKEY")
|
||||
or get_secret_str("WATSONX_API_KEY")
|
||||
or get_secret_str("WX_API_KEY")
|
||||
)
|
||||
if token is None:
|
||||
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
|
||||
if project_id is None:
|
||||
project_id = (
|
||||
get_secret_str("WATSONX_PROJECT_ID")
|
||||
|
@ -129,34 +112,6 @@ def _get_api_params(
|
|||
or get_secret_str("SPACE_ID")
|
||||
)
|
||||
|
||||
# credentials parsing
|
||||
if wx_credentials is not None:
|
||||
url = wx_credentials.get("url", url)
|
||||
api_key = wx_credentials.get("apikey", wx_credentials.get("api_key", api_key))
|
||||
token = wx_credentials.get(
|
||||
"token",
|
||||
wx_credentials.get(
|
||||
"watsonx_token", token
|
||||
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
|
||||
)
|
||||
|
||||
# verify that all required credentials are present
|
||||
if url is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
|
||||
)
|
||||
|
||||
if token is None and api_key is not None and generate_token:
|
||||
# generate the auth token
|
||||
if print_verbose is not None:
|
||||
print_verbose("Generating IAM token for Watsonx.ai")
|
||||
token = generate_iam_token(api_key)
|
||||
elif token is None and api_key is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
|
||||
)
|
||||
if project_id is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
|
@ -164,11 +119,147 @@ def _get_api_params(
|
|||
)
|
||||
|
||||
return WatsonXAPIParams(
|
||||
url=url,
|
||||
api_key=api_key,
|
||||
token=cast(str, token),
|
||||
project_id=project_id,
|
||||
space_id=space_id,
|
||||
region_name=region_name,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
|
||||
def convert_watsonx_messages_to_prompt(
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
provider: str,
|
||||
custom_prompt_dict: Dict,
|
||||
) -> str:
|
||||
# handle anthropic prompts and amazon titan prompts
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_dict = custom_prompt_dict[model]
|
||||
prompt = ptf.custom_prompt(
|
||||
messages=messages,
|
||||
role_dict=model_prompt_dict.get(
|
||||
"role_dict", model_prompt_dict.get("roles")
|
||||
),
|
||||
initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_dict.get("bos_token", ""),
|
||||
eos_token=model_prompt_dict.get("eos_token", ""),
|
||||
)
|
||||
return prompt
|
||||
elif provider == "ibm-mistralai":
|
||||
prompt = ptf.mistral_instruct_pt(messages=messages)
|
||||
else:
|
||||
prompt: str = ptf.prompt_factory( # type: ignore
|
||||
model=model, messages=messages, custom_llm_provider="watsonx"
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
# Mixin class for shared IBM Watson X functionality
|
||||
class IBMWatsonXMixin:
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: Dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: Dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> Dict:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
token = cast(Optional[str], optional_params.get("token"))
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
else:
|
||||
token = _generate_watsonx_token(api_key=api_key, token=token)
|
||||
# build auth headers
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
return headers
|
||||
|
||||
def _get_base_url(self, api_base: Optional[str]) -> str:
|
||||
url = (
|
||||
api_base
|
||||
or get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
|
||||
or get_secret_str("WATSONX_URL")
|
||||
or get_secret_str("WX_URL")
|
||||
or get_secret_str("WML_URL")
|
||||
)
|
||||
|
||||
if url is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx URL not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
|
||||
)
|
||||
return url
|
||||
|
||||
def _add_api_version_to_url(self, url: str, api_version: Optional[str]) -> str:
|
||||
api_version = api_version or litellm.WATSONX_DEFAULT_API_VERSION
|
||||
url = url + f"?version={api_version}"
|
||||
|
||||
return url
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return WatsonXAIError(
|
||||
status_code=status_code, message=error_message, headers=headers
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_watsonx_credentials(
|
||||
optional_params: dict, api_key: Optional[str], api_base: Optional[str]
|
||||
) -> WatsonXCredentials:
|
||||
api_key = (
|
||||
api_key
|
||||
or optional_params.pop("apikey", None)
|
||||
or get_secret_str("WATSONX_APIKEY")
|
||||
or get_secret_str("WATSONX_API_KEY")
|
||||
or get_secret_str("WX_API_KEY")
|
||||
)
|
||||
|
||||
api_base = (
|
||||
api_base
|
||||
or optional_params.pop(
|
||||
"url",
|
||||
optional_params.pop("api_base", optional_params.pop("base_url", None)),
|
||||
)
|
||||
or get_secret_str("WATSONX_API_BASE")
|
||||
or get_secret_str("WATSONX_URL")
|
||||
or get_secret_str("WX_URL")
|
||||
or get_secret_str("WML_URL")
|
||||
)
|
||||
|
||||
wx_credentials = optional_params.pop(
|
||||
"wx_credentials",
|
||||
optional_params.pop(
|
||||
"watsonx_credentials", None
|
||||
), # follow {provider}_credentials, same as vertex ai
|
||||
)
|
||||
|
||||
token: Optional[str] = None
|
||||
if wx_credentials is not None:
|
||||
api_base = wx_credentials.get("url", api_base)
|
||||
api_key = wx_credentials.get(
|
||||
"apikey", wx_credentials.get("api_key", api_key)
|
||||
)
|
||||
token = wx_credentials.get(
|
||||
"token",
|
||||
wx_credentials.get(
|
||||
"watsonx_token", None
|
||||
), # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
|
||||
)
|
||||
if api_key is None or not isinstance(api_key, str):
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx API key not set. Set WATSONX_API_KEY in environment variables or pass in as parameter - 'api_key='.",
|
||||
)
|
||||
if api_base is None or not isinstance(api_base, str):
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: Watsonx API base not set. Set WATSONX_API_BASE in environment variables or pass in as parameter - 'api_base='.",
|
||||
)
|
||||
return WatsonXCredentials(
|
||||
api_key=api_key, api_base=api_base, token=cast(Optional[str], token)
|
||||
)
|
||||
|
|
|
@ -1,551 +1,3 @@
|
|||
import asyncio
|
||||
import json # noqa: E401
|
||||
import time
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates import factory as ptf
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.watsonx import WatsonXAIEndpoint
|
||||
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
|
||||
|
||||
from ...base import BaseLLM
|
||||
from ..common_utils import WatsonXAIError, _get_api_params
|
||||
from .transformation import IBMWatsonXAIConfig
|
||||
|
||||
|
||||
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
|
||||
# handle anthropic prompts and amazon titan prompts
|
||||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_dict = custom_prompt_dict[model]
|
||||
prompt = ptf.custom_prompt(
|
||||
messages=messages,
|
||||
role_dict=model_prompt_dict.get(
|
||||
"role_dict", model_prompt_dict.get("roles")
|
||||
),
|
||||
initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
|
||||
bos_token=model_prompt_dict.get("bos_token", ""),
|
||||
eos_token=model_prompt_dict.get("eos_token", ""),
|
||||
)
|
||||
return prompt
|
||||
elif provider == "ibm-mistralai":
|
||||
prompt = ptf.mistral_instruct_pt(messages=messages)
|
||||
else:
|
||||
prompt: str = ptf.prompt_factory( # type: ignore
|
||||
model=model, messages=messages, custom_llm_provider="watsonx"
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
class IBMWatsonXAI(BaseLLM):
|
||||
"""
|
||||
Class to interface with IBM watsonx.ai API for text generation and embeddings.
|
||||
|
||||
Reference: https://cloud.ibm.com/apidocs/watsonx-ai
|
||||
Watsonx uses the llm_http_handler.py to handle the requests.
|
||||
"""
|
||||
|
||||
api_version = "2024-03-13"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _prepare_text_generation_req(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[AllMessageValues],
|
||||
prompt: str,
|
||||
stream: bool,
|
||||
optional_params: dict,
|
||||
print_verbose: Optional[Callable] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get the request parameters for text generation.
|
||||
"""
|
||||
api_params = _get_api_params(optional_params, print_verbose=print_verbose)
|
||||
# build auth headers
|
||||
api_token = api_params.get("token")
|
||||
self.token = api_token
|
||||
headers = IBMWatsonXAIConfig().validate_environment(
|
||||
headers={},
|
||||
model=model_id,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_token,
|
||||
)
|
||||
extra_body_params = optional_params.pop("extra_body", {})
|
||||
optional_params.update(extra_body_params)
|
||||
# init the payload to the text generation call
|
||||
payload = {
|
||||
"input": prompt,
|
||||
"moderations": optional_params.pop("moderations", {}),
|
||||
"parameters": optional_params,
|
||||
}
|
||||
request_params = dict(version=api_params["api_version"])
|
||||
# text generation endpoint deployment or model / stream or not
|
||||
if model_id.startswith("deployment/"):
|
||||
# deployment models are passed in as 'deployment/<deployment_id>'
|
||||
if api_params.get("space_id") is None:
|
||||
raise WatsonXAIError(
|
||||
status_code=401,
|
||||
message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
|
||||
)
|
||||
deployment_id = "/".join(model_id.split("/")[1:])
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
|
||||
if stream
|
||||
else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
|
||||
)
|
||||
endpoint = endpoint.format(deployment_id=deployment_id)
|
||||
else:
|
||||
payload["model_id"] = model_id
|
||||
payload["project_id"] = api_params["project_id"]
|
||||
endpoint = (
|
||||
WatsonXAIEndpoint.TEXT_GENERATION_STREAM
|
||||
if stream
|
||||
else WatsonXAIEndpoint.TEXT_GENERATION
|
||||
)
|
||||
url = api_params["url"].rstrip("/") + endpoint
|
||||
return dict(
|
||||
method="POST", url=url, headers=headers, json=payload, params=request_params
|
||||
)
|
||||
|
||||
def _process_text_gen_response(
|
||||
self, json_resp: dict, model_response: Union[ModelResponse, None] = None
|
||||
) -> ModelResponse:
|
||||
if "results" not in json_resp:
|
||||
raise WatsonXAIError(
|
||||
status_code=500,
|
||||
message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
|
||||
)
|
||||
if model_response is None:
|
||||
model_response = ModelResponse(model=json_resp.get("model_id", None))
|
||||
generated_text = json_resp["results"][0]["generated_text"]
|
||||
prompt_tokens = json_resp["results"][0]["input_token_count"]
|
||||
completion_tokens = json_resp["results"][0]["generated_token_count"]
|
||||
model_response.choices[0].message.content = generated_text # type: ignore
|
||||
model_response.choices[0].finish_reason = map_finish_reason(
|
||||
json_resp["results"][0]["stop_reason"]
|
||||
)
|
||||
if json_resp.get("created_at"):
|
||||
model_response.created = int(
|
||||
datetime.fromisoformat(json_resp["created_at"]).timestamp()
|
||||
)
|
||||
else:
|
||||
model_response.created = int(time.time())
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj: Any,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
timeout=None,
|
||||
):
|
||||
"""
|
||||
Send a text generation request to the IBM Watsonx.ai API.
|
||||
Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
|
||||
"""
|
||||
stream = optional_params.pop("stream", False)
|
||||
|
||||
# Load default configs
|
||||
config = IBMWatsonXAIConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params:
|
||||
optional_params[k] = v
|
||||
|
||||
# Make prompt to send to model
|
||||
provider = model.split("/")[0]
|
||||
# model_name = "/".join(model.split("/")[1:])
|
||||
prompt = convert_messages_to_prompt(
|
||||
model, messages, provider, custom_prompt_dict
|
||||
)
|
||||
model_response.model = model
|
||||
|
||||
def process_stream_response(
|
||||
stream_resp: Union[Iterator[str], AsyncIterator],
|
||||
) -> CustomStreamWrapper:
|
||||
streamwrapper = litellm.CustomStreamWrapper(
|
||||
stream_resp,
|
||||
model=model,
|
||||
custom_llm_provider="watsonx",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
return streamwrapper
|
||||
|
||||
# create the function to manage the request to watsonx.ai
|
||||
self.request_manager = RequestManager(logging_obj)
|
||||
|
||||
def handle_text_request(request_params: dict) -> ModelResponse:
|
||||
with self.request_manager.request(
|
||||
request_params,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
|
||||
return self._process_text_gen_response(json_resp, model_response)
|
||||
|
||||
async def handle_text_request_async(request_params: dict) -> ModelResponse:
|
||||
async with self.request_manager.async_request(
|
||||
request_params,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
return self._process_text_gen_response(json_resp, model_response)
|
||||
|
||||
def handle_stream_request(request_params: dict) -> CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
with self.request_manager.request(
|
||||
request_params,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
streamwrapper = process_stream_response(resp.iter_lines())
|
||||
return streamwrapper
|
||||
|
||||
async def handle_stream_request_async(
|
||||
request_params: dict,
|
||||
) -> CustomStreamWrapper:
|
||||
# stream the response - generated chunks will be handled
|
||||
# by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
|
||||
async with self.request_manager.async_request(
|
||||
request_params,
|
||||
stream=True,
|
||||
input=prompt,
|
||||
timeout=timeout,
|
||||
) as resp:
|
||||
streamwrapper = process_stream_response(resp.aiter_lines())
|
||||
return streamwrapper
|
||||
|
||||
try:
|
||||
## Get the response from the model
|
||||
req_params = self._prepare_text_generation_req(
|
||||
model_id=model,
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
optional_params=optional_params,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if stream and (acompletion is True):
|
||||
# stream and async text generation
|
||||
return handle_stream_request_async(req_params)
|
||||
elif stream:
|
||||
# streaming text generation
|
||||
return handle_stream_request(req_params)
|
||||
elif acompletion is True:
|
||||
# async text generation
|
||||
return handle_text_request_async(req_params)
|
||||
else:
|
||||
# regular text generation
|
||||
return handle_text_request(req_params)
|
||||
except WatsonXAIError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise WatsonXAIError(status_code=500, message=str(e))
|
||||
|
||||
def _process_embedding_response(
|
||||
self, json_resp: dict, model_response: Optional[EmbeddingResponse] = None
|
||||
) -> EmbeddingResponse:
|
||||
if model_response is None:
|
||||
model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
|
||||
results = json_resp.get("results", [])
|
||||
embedding_response = []
|
||||
for idx, result in enumerate(results):
|
||||
embedding_response.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": result["embedding"],
|
||||
}
|
||||
)
|
||||
model_response.object = "list"
|
||||
model_response.data = embedding_response
|
||||
input_tokens = json_resp.get("input_token_count", 0)
|
||||
setattr(
|
||||
model_response,
|
||||
"usage",
|
||||
Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=input_tokens,
|
||||
),
|
||||
)
|
||||
return model_response
|
||||
|
||||
def embedding(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[list, str],
|
||||
model_response: EmbeddingResponse,
|
||||
api_key: Optional[str],
|
||||
logging_obj: Any,
|
||||
optional_params: dict,
|
||||
encoding=None,
|
||||
print_verbose=None,
|
||||
aembedding=None,
|
||||
) -> EmbeddingResponse:
|
||||
"""
|
||||
Send a text embedding request to the IBM Watsonx.ai API.
|
||||
"""
|
||||
if optional_params is None:
|
||||
optional_params = {}
|
||||
# Load default configs
|
||||
config = IBMWatsonXAIConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params:
|
||||
optional_params[k] = v
|
||||
|
||||
model_response.model = model
|
||||
|
||||
# Load auth variables from environment variables
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if api_key is not None:
|
||||
optional_params["api_key"] = api_key
|
||||
api_params = _get_api_params(optional_params)
|
||||
# build auth headers
|
||||
api_token = api_params.get("token")
|
||||
self.token = api_token
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
# init the payload to the text generation call
|
||||
payload = {
|
||||
"inputs": input,
|
||||
"model_id": model,
|
||||
"project_id": api_params["project_id"],
|
||||
"parameters": optional_params,
|
||||
}
|
||||
request_params = dict(version=api_params["api_version"])
|
||||
url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS
|
||||
req_params = {
|
||||
"method": "POST",
|
||||
"url": url,
|
||||
"headers": headers,
|
||||
"json": payload,
|
||||
"params": request_params,
|
||||
}
|
||||
request_manager = RequestManager(logging_obj)
|
||||
|
||||
def handle_embedding(request_params: dict) -> EmbeddingResponse:
|
||||
with request_manager.request(request_params, input=input) as resp:
|
||||
json_resp = resp.json()
|
||||
return self._process_embedding_response(json_resp, model_response)
|
||||
|
||||
async def handle_aembedding(request_params: dict) -> EmbeddingResponse:
|
||||
async with request_manager.async_request(
|
||||
request_params, input=input
|
||||
) as resp:
|
||||
json_resp = resp.json()
|
||||
            return self._process_embedding_response(json_resp, model_response)

        try:
            if aembedding is True:
                return handle_aembedding(req_params)  # type: ignore
            else:
                return handle_embedding(req_params)
        except WatsonXAIError as e:
            raise e
        except Exception as e:
            raise WatsonXAIError(status_code=500, message=str(e))

    def get_available_models(self, *, ids_only: bool = True, **params):
        api_params = _get_api_params(params)
        self.token = api_params["token"]
        headers = {
            "Authorization": f"Bearer {api_params['token']}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        request_params = dict(version=api_params["api_version"])
        url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.AVAILABLE_MODELS
        req_params = dict(method="GET", url=url, headers=headers, params=request_params)
        with RequestManager(logging_obj=None).request(req_params) as resp:
            json_resp = resp.json()
        if not ids_only:
            return json_resp
        return [res["model_id"] for res in json_resp["resources"]]


class RequestManager:
    """
    A class to handle sync/async HTTP requests to the IBM Watsonx.ai API.

    Usage:
    ```python
    request_params = dict(method="POST", url="https://api.example.com", headers={"Authorization" : "Bearer token"}, json={"key": "value"})
    request_manager = RequestManager(logging_obj=logging_obj)
    with request_manager.request(request_params) as resp:
        ...
    # or
    async with request_manager.async_request(request_params) as resp:
        ...
    ```
    """

    def __init__(self, logging_obj=None):
        self.logging_obj = logging_obj

    def pre_call(
        self,
        request_params: dict,
        input: Optional[Any] = None,
        is_async: Optional[bool] = False,
    ):
        if self.logging_obj is None:
            return
        request_str = (
            f"response = {'await ' if is_async else ''}{request_params['method']}(\n"
            f"\turl={request_params['url']},\n"
            f"\tjson={request_params.get('json')},\n"
            f")"
        )
        self.logging_obj.pre_call(
            input=input,
            api_key=request_params["headers"].get("Authorization"),
            additional_args={
                "complete_input_dict": request_params.get("json"),
                "request_str": request_str,
            },
        )

    def post_call(self, resp, request_params):
        if self.logging_obj is None:
            return
        self.logging_obj.post_call(
            input=input,
            api_key=request_params["headers"].get("Authorization"),
            original_response=json.dumps(resp.json()),
            additional_args={
                "status_code": resp.status_code,
                "complete_input_dict": request_params.get(
                    "data", request_params.get("json")
                ),
            },
        )

    @contextmanager
    def request(
        self,
        request_params: dict,
        stream: bool = False,
        input: Optional[Any] = None,
        timeout=None,
    ) -> Generator[requests.Response, None, None]:
        """
        Returns a context manager that yields the response from the request.
        """
        self.pre_call(request_params, input)
        if timeout:
            request_params["timeout"] = timeout
        if stream:
            request_params["stream"] = stream
        try:
            resp = requests.request(**request_params)
            if not resp.ok:
                raise WatsonXAIError(
                    status_code=resp.status_code,
                    message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
                )
            yield resp
        except Exception as e:
            raise WatsonXAIError(status_code=500, message=str(e))
        if not stream:
            self.post_call(resp, request_params)

    @asynccontextmanager
    async def async_request(
        self,
        request_params: dict,
        stream: bool = False,
        input: Optional[Any] = None,
        timeout=None,
    ) -> AsyncGenerator[httpx.Response, None]:
        self.pre_call(request_params, input, is_async=True)
        if timeout:
            request_params["timeout"] = timeout
        if stream:
            request_params["stream"] = stream
        try:
            self.async_handler = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.WATSONX,
                params={
                    "timeout": httpx.Timeout(
                        timeout=request_params.pop("timeout", 600.0), connect=5.0
                    ),
                },
            )
            if "json" in request_params:
                request_params["data"] = json.dumps(request_params.pop("json", {}))
            method = request_params.pop("method")
            retries = 0
            resp: Optional[httpx.Response] = None
            while retries < 3:
                if method.upper() == "POST":
                    resp = await self.async_handler.post(**request_params)
                else:
                    resp = await self.async_handler.get(**request_params)
                if resp is not None and resp.status_code in [429, 503, 504, 520]:
                    # to handle rate limiting and service unavailable errors
                    # see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
                    await asyncio.sleep(2**retries)
                    retries += 1
                else:
                    break
            if resp is None:
                raise WatsonXAIError(
                    status_code=500,
                    message="No response from the server",
                )
            if resp.is_error:
                error_reason = getattr(resp, "reason", "")
                raise WatsonXAIError(
                    status_code=resp.status_code,
                    message=f"Error {resp.status_code} ({error_reason}): {resp.text}",
                )
            yield resp
            # await async_handler.close()
        except Exception as e:
            raise e
            raise WatsonXAIError(status_code=500, message=str(e))
        if not stream:
            self.post_call(resp, request_params)
@@ -1,13 +1,31 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import time
from datetime import datetime
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Union,
)

import httpx

from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUsageBlock
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import GenericStreamingChunk, ModelResponse, Usage
from litellm.utils import map_finish_reason

from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import WatsonXAIError
from ..common_utils import (
    IBMWatsonXMixin,
    WatsonXAIError,
    _get_api_params,
    convert_watsonx_messages_to_prompt,
)

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

@@ -17,7 +35,7 @@ else:
    LiteLLMLoggingObj = Any


class IBMWatsonXAIConfig(BaseConfig):
class IBMWatsonXAIConfig(IBMWatsonXMixin, BaseConfig):
    """
    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)

@@ -210,13 +228,6 @@ class IBMWatsonXAIConfig(BaseConfig):
        "us-south",
    ]

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        return WatsonXAIError(
            status_code=status_code, message=error_message, headers=headers
        )

    def transform_request(
        self,
        model: str,

@@ -225,9 +236,28 @@ class IBMWatsonXAIConfig(BaseConfig):
        litellm_params: Dict,
        headers: Dict,
    ) -> Dict:
        raise NotImplementedError(
            "transform_request not implemented. Done in watsonx/completion handler.py"
        provider = model.split("/")[0]
        prompt = convert_watsonx_messages_to_prompt(
            model=model,
            messages=messages,
            provider=provider,
            custom_prompt_dict={},
        )
        extra_body_params = optional_params.pop("extra_body", {})
        optional_params.update(extra_body_params)
        watsonx_api_params = _get_api_params(params=optional_params)
        # init the payload to the text generation call
        payload = {
            "input": prompt,
            "moderations": optional_params.pop("moderations", {}),
            "parameters": optional_params,
        }

        if not model.startswith("deployment/"):
            payload["model_id"] = model
            payload["project_id"] = watsonx_api_params["project_id"]

        return payload

    def transform_response(
        self,

@@ -243,22 +273,120 @@ class IBMWatsonXAIConfig(BaseConfig):
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "transform_response not implemented. Done in watsonx/completion handler.py"
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key="",
            original_response=raw_response.text,
        )

    def validate_environment(
        json_resp = raw_response.json()

        if "results" not in json_resp:
            raise WatsonXAIError(
                status_code=500,
                message=f"Error: Invalid response from Watsonx.ai API: {json_resp}",
            )
        if model_response is None:
            model_response = ModelResponse(model=json_resp.get("model_id", None))
        generated_text = json_resp["results"][0]["generated_text"]
        prompt_tokens = json_resp["results"][0]["input_token_count"]
        completion_tokens = json_resp["results"][0]["generated_token_count"]
        model_response.choices[0].message.content = generated_text  # type: ignore
        model_response.choices[0].finish_reason = map_finish_reason(
            json_resp["results"][0]["stop_reason"]
        )
        if json_resp.get("created_at"):
            model_response.created = int(
                datetime.fromisoformat(json_resp["created_at"]).timestamp()
            )
        else:
            model_response.created = int(time.time())
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def get_complete_url(
        self,
        headers: Dict,
        api_base: str,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        api_key: Optional[str] = None,
    ) -> Dict:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        if model.startswith("deployment/"):
            # deployment models are passed in as 'deployment/<deployment_id>'
            if optional_params.get("space_id") is None:
                raise WatsonXAIError(
                    status_code=401,
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = (
                WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM.value
                if stream
                else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION.value
            )
            endpoint = endpoint.format(deployment_id=deployment_id)
        else:
            endpoint = (
                WatsonXAIEndpoint.TEXT_GENERATION_STREAM
                if stream
                else WatsonXAIEndpoint.TEXT_GENERATION
            )
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return WatsonxTextCompletionResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class WatsonxTextCompletionResponseIterator(BaseModelResponseIterator):
    # def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
    #     return self.chunk_parser(json.loads(str_line))

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            results = chunk.get("results", [])
            if len(results) > 0:
                text = results[0].get("generated_text", "")
                finish_reason = results[0].get("stop_reason")
                is_finished = finish_reason != "not_finished"

                return GenericStreamingChunk(
                    text=text,
                    is_finished=is_finished,
                    finish_reason=finish_reason,
                    usage=ChatCompletionUsageBlock(
                        prompt_tokens=results[0].get("input_token_count", 0),
                        completion_tokens=results[0].get("generated_token_count", 0),
                        total_tokens=results[0].get("input_token_count", 0)
                        + results[0].get("generated_token_count", 0),
                    ),
                )
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="stop",
                usage=None,
            )
        except Exception as e:
            raise e
116 litellm/llms/watsonx/embed/transformation.py Normal file
@@ -0,0 +1,116 @@
"""
Translates from OpenAI's `/v1/embeddings` to IBM's `/text/embeddings` route.
"""

from typing import Optional

import httpx

from litellm.llms.base_llm.embedding.transformation import (
    BaseEmbeddingConfig,
    LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.utils import EmbeddingResponse, Usage

from ..common_utils import IBMWatsonXMixin, WatsonXAIError, _get_api_params


class IBMWatsonXEmbeddingConfig(IBMWatsonXMixin, BaseEmbeddingConfig):
    def get_supported_openai_params(self, model: str) -> list:
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        return optional_params

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        watsonx_api_params = _get_api_params(params=optional_params)
        project_id = watsonx_api_params["project_id"]
        if not project_id:
            raise ValueError("project_id is required")
        return {
            "inputs": input,
            "model_id": model,
            "project_id": project_id,
            "parameters": optional_params,
        }

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        url = self._get_base_url(api_base=api_base)
        endpoint = WatsonXAIEndpoint.EMBEDDINGS.value
        if model.startswith("deployment/"):
            # deployment models are passed in as 'deployment/<deployment_id>'
            if optional_params.get("space_id") is None:
                raise WatsonXAIError(
                    status_code=401,
                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
                )
            deployment_id = "/".join(model.split("/")[1:])
            endpoint = endpoint.format(deployment_id=deployment_id)
        url = url.rstrip("/") + endpoint

        ## add api version
        url = self._add_api_version_to_url(
            url=url, api_version=optional_params.pop("api_version", None)
        )
        return url

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        logging_obj.post_call(
            original_response=raw_response.text,
        )
        json_resp = raw_response.json()
        if model_response is None:
            model_response = EmbeddingResponse(model=json_resp.get("model_id", None))
        results = json_resp.get("results", [])
        embedding_response = []
        for idx, result in enumerate(results):
            embedding_response.append(
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": result["embedding"],
                }
            )
        model_response.object = "list"
        model_response.data = embedding_response
        input_tokens = json_resp.get("input_token_count", 0)
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=input_tokens,
                completion_tokens=0,
                total_tokens=input_tokens,
            ),
        )
        return model_response
134 litellm/main.py
@@ -145,7 +145,7 @@ from .llms.vertex_ai.vertex_embeddings.embedding_handler import VertexEmbedding
from .llms.vertex_ai.vertex_model_garden.main import VertexAIModelGardenModels
from .llms.vllm.completion import handler as vllm_handler
from .llms.watsonx.chat.handler import WatsonXChatHandler
from .llms.watsonx.completion.handler import IBMWatsonXAI
from .llms.watsonx.common_utils import IBMWatsonXMixin
from .types.llms.openai import (
    ChatCompletionAssistantMessage,
    ChatCompletionAudioParam,

@@ -205,7 +205,6 @@ google_batch_embeddings = GoogleBatchEmbeddings()
vertex_partner_models_chat_completion = VertexAIPartnerModels()
vertex_model_garden_chat_completion = VertexAIModelGardenModels()
vertex_text_to_speech = VertexTextToSpeechAPI()
watsonxai = IBMWatsonXAI()
sagemaker_llm = SagemakerLLM()
watsonx_chat_completion = WatsonXChatHandler()
openai_like_embedding = OpenAILikeEmbeddingHandler()

@@ -2585,43 +2584,68 @@ def completion(  # type: ignore # noqa: PLR0915
            custom_llm_provider="watsonx",
        )
    elif custom_llm_provider == "watsonx_text":
        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
        response = watsonxai.completion(
            model=model,
            messages=messages,
            custom_prompt_dict=custom_prompt_dict,
            model_response=model_response,
            print_verbose=print_verbose,
            optional_params=optional_params,
            litellm_params=litellm_params,  # type: ignore
            logger_fn=logger_fn,
            encoding=encoding,
            logging_obj=logging,
            timeout=timeout,  # type: ignore
            acompletion=acompletion,
        )
        if (
            "stream" in optional_params
            and optional_params["stream"] is True
            and not isinstance(response, CustomStreamWrapper)
        ):
            # don't try to access stream object,
            response = CustomStreamWrapper(
                iter(response),
                model,
                custom_llm_provider="watsonx",
                logging_obj=logging,
        api_key = (
            api_key
            or optional_params.pop("apikey", None)
            or get_secret_str("WATSONX_APIKEY")
            or get_secret_str("WATSONX_API_KEY")
            or get_secret_str("WX_API_KEY")
        )

        if optional_params.get("stream", False):
            ## LOGGING
            logging.post_call(
                input=messages,
                api_key=None,
                original_response=response,
        api_base = (
            api_base
            or optional_params.pop(
                "url",
                optional_params.pop(
                    "api_base", optional_params.pop("base_url", None)
                ),
            )
            or get_secret_str("WATSONX_API_BASE")
            or get_secret_str("WATSONX_URL")
            or get_secret_str("WX_URL")
            or get_secret_str("WML_URL")
        )

        wx_credentials = optional_params.pop(
            "wx_credentials",
            optional_params.pop(
                "watsonx_credentials", None
            ),  # follow {provider}_credentials, same as vertex ai
        )

        token: Optional[str] = None
        if wx_credentials is not None:
            api_base = wx_credentials.get("url", api_base)
            api_key = wx_credentials.get(
                "apikey", wx_credentials.get("api_key", api_key)
            )
            token = wx_credentials.get(
                "token",
                wx_credentials.get(
                    "watsonx_token", None
                ),  # follow format of {provider}_token, same as azure - e.g. 'azure_ad_token=..'
            )

        if token is not None:
            optional_params["token"] = token

        response = base_llm_http_handler.completion(
            model=model,
            stream=stream,
            messages=messages,
            acompletion=acompletion,
            api_base=api_base,
            model_response=model_response,
            optional_params=optional_params,
            litellm_params=litellm_params,
            custom_llm_provider="watsonx_text",
            timeout=timeout,
            headers=headers,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
            client=client,
        )
        ## RESPONSE OBJECT
        response = response
    elif custom_llm_provider == "vllm":
        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
        model_response = vllm_handler.completion(

@@ -3485,6 +3509,7 @@ def embedding(  # noqa: PLR0915
            optional_params=optional_params,
            client=client,
            aembedding=aembedding,
            litellm_params={},
        )
    elif custom_llm_provider == "gemini":
        gemini_api_key = (

@@ -3661,6 +3686,32 @@ def embedding(  # noqa: PLR0915
            optional_params=optional_params,
            client=client,
            aembedding=aembedding,
            litellm_params={},
        )
    elif custom_llm_provider == "watsonx":
        credentials = IBMWatsonXMixin.get_watsonx_credentials(
            optional_params=optional_params, api_key=api_key, api_base=api_base
        )

        api_key = credentials["api_key"]
        api_base = credentials["api_base"]

        if "token" in credentials:
            optional_params["token"] = credentials["token"]

        response = base_llm_http_handler.embedding(
            model=model,
            input=input,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            api_key=api_key,
            logging_obj=logging,
            timeout=timeout,
            model_response=EmbeddingResponse(),
            optional_params=optional_params,
            litellm_params={},
            client=client,
            aembedding=aembedding,
        )
    elif custom_llm_provider == "xinference":
        api_key = (

@@ -3687,17 +3738,6 @@ def embedding(  # noqa: PLR0915
            client=client,
            aembedding=aembedding,
        )
    elif custom_llm_provider == "watsonx":
        response = watsonxai.embedding(
            model=model,
            input=input,
            encoding=encoding,
            logging_obj=logging,
            optional_params=optional_params,
            model_response=EmbeddingResponse(),
            aembedding=aembedding,
            api_key=api_key,
        )
    elif custom_llm_provider == "azure_ai":
        api_base = (
            api_base  # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
@@ -611,3 +611,6 @@ class FineTuningJobCreate(BaseModel):

class LiteLLMFineTuningJobCreate(FineTuningJobCreate):
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"]


AllEmbeddingInputValues = Union[str, List[str], List[int], List[List[int]]]

@@ -6,13 +6,15 @@ from pydantic import BaseModel


class WatsonXAPIParams(TypedDict):
    url: str
    api_key: Optional[str]
    token: str
    project_id: str
    space_id: Optional[str]
    region_name: Optional[str]
    api_version: str


class WatsonXCredentials(TypedDict):
    api_key: str
    api_base: str
    token: Optional[str]


class WatsonXAIEndpoint(str, Enum):

@@ -808,6 +808,8 @@ class ModelResponseStream(ModelResponseBase):
    def __init__(
        self,
        choices: Optional[List[Union[StreamingChoices, dict, BaseModel]]] = None,
        id: Optional[str] = None,
        created: Optional[int] = None,
        **kwargs,
    ):
        if choices is not None and isinstance(choices, list):

@@ -824,6 +826,20 @@ class ModelResponseStream(ModelResponseBase):
            kwargs["choices"] = new_choices
        else:
            kwargs["choices"] = [StreamingChoices()]

        if id is None:
            id = _generate_id()
        else:
            id = id
        if created is None:
            created = int(time.time())
        else:
            created = created

        kwargs["id"] = id
        kwargs["created"] = created
        kwargs["object"] = "chat.completion.chunk"

        super().__init__(**kwargs)

    def __contains__(self, key):

@@ -6244,7 +6244,9 @@ class ProviderConfigManager:
            return litellm.VoyageEmbeddingConfig()
        elif litellm.LlmProviders.TRITON == provider:
            return litellm.TritonEmbeddingConfig()
        raise ValueError(f"Provider {provider} does not support embedding config")
        elif litellm.LlmProviders.WATSONX == provider:
            return litellm.IBMWatsonXEmbeddingConfig()
        raise ValueError(f"Provider {provider.value} does not support embedding config")

    @staticmethod
    def get_provider_rerank_config(

2542 poetry.lock generated
File diff suppressed because it is too large

@@ -26,7 +26,6 @@ tokenizers = "*"
click = "*"
jinja2 = "^3.1.2"
aiohttp = "*"
requests = "^2.31.0"
pydantic = "^2.0.0"
jsonschema = "^4.22.0"

@@ -133,7 +133,7 @@ def test_completion_xai(stream):
        for chunk in response:
            print(chunk)
            assert chunk is not None
            assert isinstance(chunk, litellm.ModelResponse)
            assert isinstance(chunk, litellm.ModelResponseStream)
            assert isinstance(chunk.choices[0], litellm.utils.StreamingChoices)

    else:
|
|
|
@ -1,104 +0,0 @@
|
|||
============================= test session starts ==============================
|
||||
platform darwin -- Python 3.11.4, pytest-8.3.2, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/myenv/bin/python3.11
|
||||
cachedir: .pytest_cache
|
||||
rootdir: /Users/krrishdholakia/Documents/litellm
|
||||
configfile: pyproject.toml
|
||||
plugins: asyncio-0.23.8, respx-0.21.1, anyio-4.6.0
|
||||
asyncio: mode=Mode.STRICT
|
||||
collecting ... collected 1 item
|
||||
|
||||
test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307] <module 'litellm' from '/Users/krrishdholakia/Documents/litellm/litellm/__init__.py'>
|
||||
|
||||
|
||||
[92mRequest to litellm:[0m
|
||||
[92mlitellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}], tools=[{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], tool_choice='auto')[0m
|
||||
|
||||
|
||||
SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
|
||||
Final returned optional params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
|
||||
optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}}
|
||||
SENT optional_params: {'tools': [{'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096}
|
||||
tool: {'type': 'function', 'function': {'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'parameters': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}}
|
||||
[92m
|
||||
|
||||
POST Request Sent from LiteLLM:
|
||||
curl -X POST \
|
||||
https://api.anthropic.com/v1/messages \
|
||||
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
|
||||
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}], 'tools': [{'name': 'get_current_weather', 'description': 'Get the current weather in a given location', 'input_schema': {'type': 'object', 'properties': {'location': {'type': 'string', 'description': 'The city and state'}, 'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']}}, 'required': ['location']}}], 'tool_choice': {'type': 'auto'}, 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'
|
||||
[0m
|
||||
|
||||
_is_function_call: False
|
||||
RAW RESPONSE:
|
||||
{"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}
|
||||
|
||||
|
||||
raw model_response: {"id":"msg_01HRugqzL4WmcxMmbvDheTph","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"Okay, let's check the current weather in those three cities:"},{"type":"tool_use","id":"toolu_016U6G3kpxjHSiJLwVCrrScz","name":"get_current_weather","input":{"location":"San Francisco","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":379,"output_tokens":87}}
|
||||
Logging Details LiteLLM-Success Call: Cache_hit=None
|
||||
Looking up model=claude-3-haiku-20240307 in model_cost_map
|
||||
Looking up model=claude-3-haiku-20240307 in model_cost_map
|
||||
Response
|
||||
ModelResponse(id='chatcmpl-7222f6c2-962a-4776-8639-576723466cb7', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None))], created=1727897483, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=87, prompt_tokens=379, total_tokens=466, completion_tokens_details=None))
|
||||
length of tool calls 1
|
||||
Expecting there to be 3 tool calls
|
||||
tool_calls: [ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')]
|
||||
Response message
|
||||
Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None)
|
||||
messages: [{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]
|
||||
|
||||
|
||||
[92mRequest to litellm:[0m
|
||||
[92mlitellm.completion(model='claude-3-haiku-20240307', messages=[{'role': 'user', 'content': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}, Message(content="Okay, let's check the current weather in those three cities:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "San Francisco", "unit": "celsius"}', name='get_current_weather'), id='toolu_016U6G3kpxjHSiJLwVCrrScz', type='function')], function_call=None), {'tool_call_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'role': 'tool', 'name': 'get_current_weather', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}], temperature=0.2, seed=22, drop_params=True)[0m
|
||||
|
||||
|
||||
SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False
|
||||
Final returned optional params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
|
||||
optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}]}
|
||||
SENT optional_params: {'temperature': 0.2, 'tools': [{'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}], 'max_tokens': 4096}
|
||||
tool: {'type': 'function', 'function': {'name': 'dummy-tool', 'description': '', 'parameters': {'type': 'object', 'properties': {}}}}
|
||||
[92m
|
||||
|
||||
POST Request Sent from LiteLLM:
|
||||
curl -X POST \
|
||||
https://api.anthropic.com/v1/messages \
|
||||
-H 'accept: *****' -H 'anthropic-version: *****' -H 'content-type: *****' -H 'x-api-key: sk-ant-api03-bJf1M8qp-JDptRcZRE5ve5efAfSIaL5u-SZ9vItIkvuFcV5cUsd********************************************' -H 'anthropic-beta: *****' \
|
||||
-d '{'messages': [{'role': 'user', 'content': [{'type': 'text', 'text': "What's the weather like in San Francisco, Tokyo, and Paris? - give me 3 responses"}]}, {'role': 'assistant', 'content': [{'type': 'tool_use', 'id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'name': 'get_current_weather', 'input': {'location': 'San Francisco', 'unit': 'celsius'}}]}, {'role': 'user', 'content': [{'type': 'tool_result', 'tool_use_id': 'toolu_016U6G3kpxjHSiJLwVCrrScz', 'content': '{"location": "San Francisco", "temperature": "72", "unit": "fahrenheit"}'}]}], 'temperature': 0.2, 'tools': [{'name': 'dummy-tool', 'description': '', 'input_schema': {'type': 'object', 'properties': {}}}], 'max_tokens': 4096, 'model': 'claude-3-haiku-20240307'}'
|
||||
[0m
|
||||
|
||||
_is_function_call: False
|
||||
RAW RESPONSE:
|
||||
{"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}
|
||||
|
||||
|
||||
raw model_response: {"id":"msg_01Wp8NVScugz6yAGsmB5trpZ","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[{"type":"text","text":"The current weather in San Francisco is 72°F (22°C)."},{"type":"tool_use","id":"toolu_01HTXEYDX4MspM76STtJqs1n","name":"get_current_weather","input":{"location":"Tokyo","unit":"celsius"}}],"stop_reason":"tool_use","stop_sequence":null,"usage":{"input_tokens":426,"output_tokens":90}}
|
||||
Logging Details LiteLLM-Success Call: Cache_hit=None
|
||||
Looking up model=claude-3-haiku-20240307 in model_cost_map
|
||||
Looking up model=claude-3-haiku-20240307 in model_cost_map
|
||||
second response
|
||||
ModelResponse(id='chatcmpl-c4ed5c25-ba7c-49e5-a6be-5720ab25fff0', choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content='The current weather in San Francisco is 72°F (22°C).', role='assistant', tool_calls=[ChatCompletionMessageToolCall(index=1, function=Function(arguments='{"location": "Tokyo", "unit": "celsius"}', name='get_current_weather'), id='toolu_01HTXEYDX4MspM76STtJqs1n', type='function')], function_call=None))], created=1727897484, model='claude-3-haiku-20240307', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=90, prompt_tokens=426, total_tokens=516, completion_tokens_details=None))
|
||||
PASSED
|
||||
|
||||
=============================== warnings summary ===============================
|
||||
../../myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284
|
||||
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/
|
||||
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
|
||||
../../litellm/utils.py:17
|
||||
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:17: DeprecationWarning: 'imghdr' is deprecated and slated for removal in Python 3.13
|
||||
import imghdr
|
||||
|
||||
../../litellm/utils.py:124
|
||||
/Users/krrishdholakia/Documents/litellm/litellm/utils.py:124: DeprecationWarning: open_text is deprecated. Use files() instead. Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice.
|
||||
with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f:
|
||||
|
||||
test_function_calling.py:56
|
||||
/Users/krrishdholakia/Documents/litellm/tests/local_testing/test_function_calling.py:56: PytestUnknownMarkWarning: Unknown pytest.mark.flaky - is this a typo? You can register custom marks to avoid this warning - for details, see https://docs.pytest.org/en/stable/how-to/mark.html
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
|
||||
tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
|
||||
tests/local_testing/test_function_calling.py::test_aaparallel_function_call[claude-3-haiku-20240307]
|
||||
/Users/krrishdholakia/Documents/litellm/myenv/lib/python3.11/site-packages/httpx/_content.py:202: DeprecationWarning: Use 'content=<...>' to upload raw bytes/text content.
|
||||
warnings.warn(message, DeprecationWarning)
|
||||
|
||||
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
|
||||
======================== 1 passed, 6 warnings in 1.89s =========================
|
|
@@ -1805,7 +1805,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):

        for chunk in response:
            chunks.append(chunk)
            assert isinstance(chunk, litellm.ModelResponse)
            assert isinstance(chunk, litellm.ModelResponseStream)
    else:
        response = await litellm.acompletion(**data)
        print(f"completion: {response}")

@@ -1815,7 +1815,7 @@ async def test_gemini_pro_function_calling_streaming(sync_mode):
        async for chunk in response:
            print(f"chunk: {chunk}")
            chunks.append(chunk)
            assert isinstance(chunk, litellm.ModelResponse)
            assert isinstance(chunk, litellm.ModelResponseStream)

    complete_response = litellm.stream_chunk_builder(chunks=chunks)
    assert (

@@ -4019,20 +4019,21 @@ def test_completion_deepseek():
@pytest.mark.skip(reason="Account deleted by IBM.")
def test_completion_watsonx_error():
    litellm.set_verbose = True
    model_name = "watsonx/ibm/granite-13b-chat-v2"
    model_name = "watsonx_text/ibm/granite-13b-chat-v2"

    with pytest.raises(litellm.BadRequestError) as e:
        response = completion(
            model=model_name,
            messages=messages,
            stop=["stop"],
            max_tokens=20,
            stream=True,
        )

        for chunk in response:
            print(chunk)
    # Add any assertions here to check the response
    print(response)

    assert "use 'watsonx_text' route instead" in str(e).lower()


@pytest.mark.skip(reason="Skip test. account deleted.")
def test_completion_stream_watsonx():

@@ -135,7 +135,7 @@ class CompletionCustomHandler(
        ## END TIME
        assert isinstance(end_time, datetime)
        ## RESPONSE OBJECT
        assert isinstance(response_obj, litellm.ModelResponse)
        assert isinstance(response_obj, litellm.ModelResponseStream)
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list) and isinstance(

@@ -153,7 +153,7 @@ class CompletionCustomHandler(
        ## END TIME
        assert isinstance(end_time, datetime)
        ## RESPONSE OBJECT
        assert isinstance(response_obj, litellm.ModelResponse)
        assert isinstance(response_obj, litellm.ModelResponseStream)
        ## KWARGS
        assert isinstance(kwargs["model"], str)
        assert isinstance(kwargs["messages"], list) and isinstance(

@@ -122,8 +122,13 @@ async def test_async_chat_openai_stream():

    complete_streaming_response = complete_streaming_response.strip("'")

    print(f"complete_streaming_response: {complete_streaming_response}")

    await asyncio.sleep(3)

    print(
        f"tmp_function.complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback}"
    )
    # problematic line
    response1 = tmp_function.complete_streaming_response_in_callback["choices"][0][
        "message"

@@ -801,8 +801,11 @@ def test_fireworks_embeddings():


def test_watsonx_embeddings():
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    def mock_wx_embed_request(method: str, url: str, **kwargs):
    client = HTTPHandler()

    def mock_wx_embed_request(url: str, **kwargs):
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.headers = {"Content-Type": "application/json"}

@@ -816,12 +819,14 @@ def test_watsonx_embeddings():

    try:
        litellm.set_verbose = True
        with patch("requests.request", side_effect=mock_wx_embed_request):
        with patch.object(client, "post", side_effect=mock_wx_embed_request):
            response = litellm.embedding(
                model="watsonx/ibm/slate-30m-english-rtrvr",
                input=["good morning from litellm"],
                token="secret-token",
                client=client,
            )

        print(f"response: {response}")
        assert isinstance(response.usage, litellm.Usage)
    except litellm.RateLimitError as e:

@@ -832,6 +837,9 @@ def test_watsonx_embeddings():

@pytest.mark.asyncio
async def test_watsonx_aembeddings():
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    client = AsyncHTTPHandler()

    def mock_async_client(*args, **kwargs):

@@ -856,12 +864,14 @@ async def test_watsonx_aembeddings():

    try:
        litellm.set_verbose = True
        with patch("httpx.AsyncClient", side_effect=mock_async_client):
        with patch.object(client, "post", side_effect=mock_async_client) as mock_client:
            response = await litellm.aembedding(
                model="watsonx/ibm/slate-30m-english-rtrvr",
                input=["good morning from litellm"],
                token="secret-token",
                client=client,
            )
            mock_client.assert_called_once()
        print(f"response: {response}")
        assert isinstance(response.usage, litellm.Usage)
    except litellm.RateLimitError as e:

@@ -17,6 +17,7 @@ from pydantic import BaseModel
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.utils import ModelResponseListIterator
from litellm.types.utils import ModelResponseStream

sys.path.insert(
    0, os.path.abspath("../..")

@@ -69,7 +70,7 @@ first_openai_chunk_example = {

def validate_first_format(chunk):
    # write a test to make sure chunk follows the same format as first_openai_chunk_example
    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
    assert isinstance(chunk, ModelResponseStream), "Chunk should be a dictionary."
    assert isinstance(chunk["id"], str), "'id' should be a string."
    assert isinstance(chunk["object"], str), "'object' should be a string."
    assert isinstance(chunk["created"], int), "'created' should be an integer."

@@ -99,7 +100,7 @@ second_openai_chunk_example = {


def validate_second_format(chunk):
    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
    assert isinstance(chunk, ModelResponseStream), "Chunk should be a dictionary."
    assert isinstance(chunk["id"], str), "'id' should be a string."
    assert isinstance(chunk["object"], str), "'object' should be a string."
    assert isinstance(chunk["created"], int), "'created' should be an integer."

@@ -137,7 +138,7 @@ def validate_last_format(chunk):
    """
    Ensure last chunk has no remaining content or tools
    """
    assert isinstance(chunk, ModelResponse), "Chunk should be a dictionary."
    assert isinstance(chunk, ModelResponseStream), "Chunk should be a dictionary."
    assert isinstance(chunk["id"], str), "'id' should be a string."
    assert isinstance(chunk["object"], str), "'object' should be a string."
    assert isinstance(chunk["created"], int), "'created' should be an integer."

@@ -1523,7 +1524,7 @@ async def test_parallel_streaming_requests(sync_mode, model):
        num_finish_reason = 0
        for chunk in response:
            print(f"chunk: {chunk}")
            if isinstance(chunk, ModelResponse):
            if isinstance(chunk, ModelResponseStream):
                if chunk.choices[0].finish_reason is not None:
                    num_finish_reason += 1
        assert num_finish_reason == 1

@@ -1541,7 +1542,7 @@ async def test_parallel_streaming_requests(sync_mode, model):
        num_finish_reason = 0
        async for chunk in response:
            print(f"type of chunk: {type(chunk)}")
            if isinstance(chunk, ModelResponse):
            if isinstance(chunk, ModelResponseStream):
                print(f"OUTSIDE CHUNK: {chunk.choices[0]}")
                if chunk.choices[0].finish_reason is not None:
                    num_finish_reason += 1