diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 3ae2db701..24a076dd0 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -487,7 +487,7 @@ def format_prompt_togetherai(messages, prompt_format, chat_template): def ibm_granite_pt(messages: list): """ - IBM's Granite chat models uses the template: + IBM's Granite models uses the template: <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message} See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models @@ -503,13 +503,12 @@ def ibm_granite_pt(messages: list): "pre_message": "<|user|>\n", "post_message": "\n", }, - 'assistant': { - 'pre_message': '<|assistant|>\n', - 'post_message': '\n', + "assistant": { + "pre_message": "<|assistant|>\n", + "post_message": "\n", }, }, - final_prompt_value='<|assistant|>\n', - ) + ).strip() ### ANTHROPIC ### @@ -1525,9 +1524,24 @@ def prompt_factory( return mistral_instruct_pt(messages=messages) elif "meta-llama/llama-3" in model and "instruct" in model: # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/ - return hf_chat_template( - model="meta-llama/Meta-Llama-3-8B-Instruct", + return custom_prompt( + role_dict={ + "system": { + "pre_message": "<|start_header_id|>system<|end_header_id|>\n", + "post_message": "<|eot_id|>", + }, + "user": { + "pre_message": "<|start_header_id|>user<|end_header_id|>\n", + "post_message": "<|eot_id|>", + }, + "assistant": { + "pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", + "post_message": "<|eot_id|>", + }, + }, messages=messages, + initial_prompt_value="<|begin_of_text|>", + final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n", ) try: if "meta-llama/llama-2" in model and "chat" in model: diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py index 082cdb325..99f2d18ba 100644 --- a/litellm/llms/watsonx.py +++ b/litellm/llms/watsonx.py @@ -1,13 +1,12 @@ from enum import Enum import json, types, time # noqa: E401 -from contextlib import asynccontextmanager, contextmanager -from typing import AsyncGenerator, Callable, Dict, Generator, Optional, Any, Union, List +from contextlib import contextmanager +from typing import Callable, Dict, Optional, Any, Union, List import httpx # type: ignore import requests # type: ignore import litellm -from litellm.utils import Logging, ModelResponse, Usage, get_secret -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.utils import ModelResponse, get_secret, Usage from .base import BaseLLM from .prompt_templates import factory as ptf @@ -193,7 +192,7 @@ class WatsonXAIEndpoint(str, Enum): class IBMWatsonXAI(BaseLLM): """ - Class to interface with IBM watsonx.ai API for text generation and embeddings. + Class to interface with IBM Watsonx.ai API for text generation and embeddings. 
Reference: https://cloud.ibm.com/apidocs/watsonx-ai """ @@ -344,7 +343,7 @@ class IBMWatsonXAI(BaseLLM): ) if token is None and api_key is not None: # generate the auth token - if print_verbose is not None: + if print_verbose: print_verbose("Generating IAM token for Watsonx.ai") token = self.generate_iam_token(api_key) elif token is None and api_key is None: @@ -378,9 +377,8 @@ class IBMWatsonXAI(BaseLLM): model_response: ModelResponse, print_verbose: Callable, encoding, - logging_obj: Logging, - optional_params: Optional[dict] = None, - acompletion: bool = None, + logging_obj, + optional_params: dict, litellm_params: Optional[dict] = None, logger_fn=None, timeout: Optional[float] = None, @@ -403,15 +401,13 @@ class IBMWatsonXAI(BaseLLM): prompt = convert_messages_to_prompt( model, messages, provider, custom_prompt_dict ) - - manage_response = self._make_response_manager(async_=(acompletion is True), logging_obj=logging_obj) - def process_text_gen_response(json_resp: dict) -> ModelResponse: - if "results" not in json_resp: - raise WatsonXAIError( - status_code=500, - message=f"Error: Invalid response from Watsonx.ai API: {json_resp}", - ) + def process_text_request(request_params: dict) -> ModelResponse: + with self._manage_response( + request_params, logging_obj=logging_obj, input=prompt, timeout=timeout + ) as resp: + json_resp = resp.json() + generated_text = json_resp["results"][0]["generated_text"] prompt_tokens = json_resp["results"][0]["input_token_count"] completion_tokens = json_resp["results"][0]["generated_token_count"] @@ -430,52 +426,25 @@ class IBMWatsonXAI(BaseLLM): ) return model_response - def handle_text_request(request_params: dict) -> ModelResponse: - with manage_response( - request_params, input=prompt, timeout=timeout, - ) as resp: - json_resp = resp.json() - - return process_text_gen_response(json_resp) - - async def handle_text_request_async(request_params: dict) -> ModelResponse: - async with manage_response( - request_params, input=prompt, timeout=timeout, - ) as resp: - json_resp = resp.json() - return process_text_gen_response(json_resp) - - def handle_stream_request( + def process_stream_request( request_params: dict, ) -> litellm.CustomStreamWrapper: # stream the response - generated chunks will be handled # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream - with manage_response( - request_params, stream=True, input=prompt, timeout=timeout, + with self._manage_response( + request_params, + logging_obj=logging_obj, + stream=True, + input=prompt, + timeout=timeout, ) as resp: - streamwrapper = litellm.CustomStreamWrapper( + response = litellm.CustomStreamWrapper( resp.iter_lines(), model=model, custom_llm_provider="watsonx", logging_obj=logging_obj, ) - return streamwrapper - - async def handle_stream_request_async( - request_params: dict, - ) -> litellm.CustomStreamWrapper: - # stream the response - generated chunks will be handled - # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream - async with manage_response( - request_params, stream=True, input=prompt, timeout=timeout, - ) as resp: - streamwrapper = litellm.CustomStreamWrapper( - resp.aiter_lines(), - model=model, - custom_llm_provider="watsonx", - logging_obj=logging_obj, - ) - return streamwrapper + return response try: ## Get the response from the model @@ -486,18 +455,10 @@ class IBMWatsonXAI(BaseLLM): optional_params=optional_params, print_verbose=print_verbose, ) - if stream and acompletion: - # stream and async text generation - return handle_stream_request_async(req_params) - 
elif stream: - # streaming text generation - return handle_stream_request(req_params) - elif acompletion: - # async text generation - return handle_text_request_async(req_params) + if stream: + return process_stream_request(req_params) else: - # regular text generation - return handle_text_request(req_params) + return process_text_request(req_params) except WatsonXAIError as e: raise e except Exception as e: @@ -512,7 +473,6 @@ class IBMWatsonXAI(BaseLLM): model_response=None, optional_params=None, encoding=None, - aembedding=None, ): """ Send a text embedding request to the IBM Watsonx.ai API. @@ -547,6 +507,9 @@ class IBMWatsonXAI(BaseLLM): } request_params = dict(version=api_params["api_version"]) url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS + # request = httpx.Request( + # "POST", url, headers=headers, json=payload, params=request_params + # ) req_params = { "method": "POST", "url": url, @@ -554,47 +517,25 @@ class IBMWatsonXAI(BaseLLM): "json": payload, "params": request_params, } - manage_response = self._make_response_manager(async_=(aembedding is True), logging_obj=logging_obj) - - def process_embedding_response(json_resp: dict) -> ModelResponse: - results = json_resp.get("results", []) - embedding_response = [] - for idx, result in enumerate(results): - embedding_response.append( - {"object": "embedding", "index": idx, "embedding": result["embedding"]} - ) - model_response["object"] = "list" - model_response["data"] = embedding_response - model_response["model"] = model - input_tokens = json_resp.get("input_token_count", 0) - model_response.usage = Usage( - prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + with self._manage_response( + req_params, logging_obj=logging_obj, input=input + ) as resp: + json_resp = resp.json() + + results = json_resp.get("results", []) + embedding_response = [] + for idx, result in enumerate(results): + embedding_response.append( + {"object": "embedding", "index": idx, "embedding": result["embedding"]} ) - return model_response - - def handle_embedding_request(request_params: dict) -> ModelResponse: - with manage_response( - request_params, input=input - ) as resp: - json_resp = resp.json() - return process_embedding_response(json_resp) - - async def handle_embedding_request_async(request_params: dict) -> ModelResponse: - async with manage_response( - request_params, input=input - ) as resp: - json_resp = resp.json() - return process_embedding_response(json_resp) - - try: - if aembedding: - return handle_embedding_request_async(req_params) - else: - return handle_embedding_request(req_params) - except WatsonXAIError as e: - raise e - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) + model_response["object"] = "list" + model_response["data"] = embedding_response + model_response["model"] = model + input_tokens = json_resp.get("input_token_count", 0) + model_response.usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) + return model_response def generate_iam_token(self, api_key=None, **params): headers = {} @@ -616,116 +557,53 @@ class IBMWatsonXAI(BaseLLM): iam_access_token = json_data["access_token"] self.token = iam_access_token return iam_access_token - - def _make_response_manager( - self, - async_: bool, - logging_obj: Logging - ) -> Callable[..., Generator[Union[requests.Response, httpx.Response], None, None]]: - """ - Returns a context manager that manages the response from the request. 
- if async_ is True, returns an async context manager, otherwise returns a regular context manager. - Usage: - ```python - manage_response = self._make_response_manager(async_=True, logging_obj=logging_obj) - async with manage_response(request_params) as resp: - ... - # or - manage_response = self._make_response_manager(async_=False, logging_obj=logging_obj) - with manage_response(request_params) as resp: - ... - ``` - """ - - def pre_call( - request_params: dict, - input:Optional[Any]=None, - ): - request_str = ( - f"response = {'await ' if async_ else ''}{request_params['method']}(\n" - f"\turl={request_params['url']},\n" - f"\tjson={request_params['json']},\n" - f")" - ) - logging_obj.pre_call( - input=input, - api_key=request_params["headers"].get("Authorization"), - additional_args={ - "complete_input_dict": request_params["json"], - "request_str": request_str, - }, - ) - - def post_call(resp, request_params): + @contextmanager + def _manage_response( + self, + request_params: dict, + logging_obj: Any, + stream: bool = False, + input: Optional[Any] = None, + timeout: Optional[float] = None, + ): + request_str = ( + f"response = {request_params['method']}(\n" + f"\turl={request_params['url']},\n" + f"\tjson={request_params['json']},\n" + f")" + ) + logging_obj.pre_call( + input=input, + api_key=request_params["headers"].get("Authorization"), + additional_args={ + "complete_input_dict": request_params["json"], + "request_str": request_str, + }, + ) + if timeout: + request_params["timeout"] = timeout + try: + if stream: + resp = requests.request( + **request_params, + stream=True, + ) + resp.raise_for_status() + yield resp + else: + resp = requests.request(**request_params) + resp.raise_for_status() + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: logging_obj.post_call( input=input, api_key=request_params["headers"].get("Authorization"), original_response=json.dumps(resp.json()), additional_args={ "status_code": resp.status_code, - "complete_input_dict": request_params.get("data", request_params.get("json")), + "complete_input_dict": request_params["json"], }, ) - - @contextmanager - def _manage_response( - request_params: dict, - stream: bool = False, - input: Optional[Any] = None, - timeout: float = None, - ) -> Generator[requests.Response, None, None]: - """ - Returns a context manager that yields the response from the request. 
- """ - pre_call(request_params, input) - if timeout: - request_params["timeout"] = timeout - if stream: - request_params["stream"] = stream - try: - resp = requests.request(**request_params) - resp.raise_for_status() - yield resp - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) - if not stream: - post_call(resp, request_params) - - - @asynccontextmanager - async def _manage_response_async( - request_params: dict, - stream: bool = False, - input: Optional[Any] = None, - timeout: float = None, - ) -> AsyncGenerator[httpx.Response, None]: - pre_call(request_params, input) - if timeout: - request_params["timeout"] = timeout - if stream: - request_params["stream"] = stream - try: - # async with AsyncHTTPHandler(timeout=timeout) as client: - self.async_handler = AsyncHTTPHandler( - timeout=httpx.Timeout(timeout=request_params.pop("timeout", 600.0), connect=5.0), - ) - # async_handler.client.verify = False - if "json" in request_params: - request_params['data'] = json.dumps(request_params.pop("json", {})) - method = request_params.pop("method") - if method.upper() == "POST": - resp = await self.async_handler.post(**request_params) - else: - resp = await self.async_handler.get(**request_params) - yield resp - # await async_handler.close() - except Exception as e: - raise WatsonXAIError(status_code=500, message=str(e)) - if not stream: - post_call(resp, request_params) - - if async_: - return _manage_response_async - else: - return _manage_response diff --git a/litellm/main.py b/litellm/main.py index 1298dec4c..99e5ec224 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -73,7 +73,6 @@ from .llms.azure_text import AzureTextCompletion from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface -from .llms.watsonx import IBMWatsonXAI from .llms.prompt_templates.factory import ( prompt_factory, custom_prompt, @@ -110,7 +109,6 @@ anthropic_text_completions = AnthropicTextCompletion() azure_chat_completions = AzureChatCompletion() azure_text_completions = AzureTextCompletion() huggingface = Huggingface() -watsonxai = IBMWatsonXAI() ####### COMPLETION ENDPOINTS ################ @@ -315,7 +313,6 @@ async def acompletion( or custom_llm_provider == "gemini" or custom_llm_provider == "sagemaker" or custom_llm_provider == "anthropic" - or custom_llm_provider == "watsonx" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. init_response = await loop.run_in_executor(None, func_with_context) @@ -1911,7 +1908,7 @@ def completion( response = response elif custom_llm_provider == "watsonx": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - response = watsonxai.completion( + response = watsonx.IBMWatsonXAI().completion( model=model, messages=messages, custom_prompt_dict=custom_prompt_dict, @@ -1922,8 +1919,7 @@ def completion( logger_fn=logger_fn, encoding=encoding, logging_obj=logging, - acompletion=acompletion, - timeout=timeout, + timeout=timeout, # type: ignore ) if ( "stream" in optional_params @@ -2576,7 +2572,6 @@ async def aembedding(*args, **kwargs): or custom_llm_provider == "fireworks_ai" or custom_llm_provider == "ollama" or custom_llm_provider == "vertex_ai" - or custom_llm_provider == "watsonx" ): # currently implemented aiohttp calls for just azure and openai, soon all. 
# Await normally init_response = await loop.run_in_executor(None, func_with_context) @@ -3034,14 +3029,13 @@ def embedding( aembedding=aembedding, ) elif custom_llm_provider == "watsonx": - response = watsonxai.embedding( + response = watsonx.IBMWatsonXAI().embedding( model=model, input=input, encoding=encoding, logging_obj=logging, optional_params=optional_params, model_response=EmbeddingResponse(), - aembedding=aembedding, ) else: args = locals() diff --git a/litellm/utils.py b/litellm/utils.py index d1af1b44a..c03d4e2bc 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -10567,18 +10567,6 @@ class CustomStreamWrapper: elif self.custom_llm_provider == "watsonx": response_obj = self.handle_watsonx_stream(chunk) completion_obj["content"] = response_obj["text"] - print_verbose(f"completion obj content: {completion_obj['content']}") - if getattr(model_response, "usage", None) is None: - model_response.usage = Usage() - if response_obj.get("prompt_tokens") is not None: - prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0) - model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"]) - if response_obj.get("completion_tokens") is not None: - model_response.usage.completion_tokens = response_obj["completion_tokens"] - model_response.usage.total_tokens = ( - getattr(model_response.usage, "prompt_tokens", 0) - + getattr(model_response.usage, "completion_tokens", 0) - ) if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "text-completion-openai": @@ -10983,7 +10971,6 @@ class CustomStreamWrapper: or self.custom_llm_provider == "sagemaker" or self.custom_llm_provider == "gemini" or self.custom_llm_provider == "cached_response" - or self.custom_llm_provider == "watsonx" or self.custom_llm_provider in litellm.openai_compatible_endpoints ): async for chunk in self.completion_stream:
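Reviewer notes — illustrative sketches, not part of the patch:

First, the new meta-llama/llama-3 branch in prompt_factory builds the prompt from the explicit role_dict instead of fetching the Hugging Face chat template. The snippet below is a hypothetical standalone rendering of that wiring (litellm's real custom_prompt() in factory.py may differ in edge cases such as unknown roles or whitespace); it is only meant to make the resulting Llama 3 token layout concrete.

    # Hypothetical sketch mirroring the role_dict, initial_prompt_value and
    # final_prompt_value added above; not litellm's custom_prompt() itself.
    LLAMA3_ROLES = {
        "system": ("<|start_header_id|>system<|end_header_id|>\n", "<|eot_id|>"),
        "user": ("<|start_header_id|>user<|end_header_id|>\n", "<|eot_id|>"),
        "assistant": ("<|start_header_id|>assistant<|end_header_id|>\n", "<|eot_id|>"),
    }

    def render_llama3_prompt(messages: list) -> str:
        prompt = "<|begin_of_text|>"
        for message in messages:
            pre, post = LLAMA3_ROLES[message["role"]]
            prompt += pre + message["content"] + post
        # final_prompt_value leaves the prompt open at an assistant header, so
        # the model generates the assistant turn next.
        return prompt + "<|start_header_id|>assistant<|end_header_id|>\n"

    print(render_llama3_prompt([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]))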
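Second, watsonx.py collapses the paired sync/async response managers into a single synchronous @contextmanager, _manage_response(). A minimal standalone sketch of that pattern, assuming plain prints in place of litellm's logging_obj.pre_call/post_call hooks and omitting the WatsonXAIError wrapping:

    import requests
    from contextlib import contextmanager

    @contextmanager
    def managed_request(request_params: dict, stream: bool = False):
        # Stand-in for logging_obj.pre_call in the real _manage_response.
        print(f"pre_call: {request_params['method']} {request_params['url']}")
        resp = requests.request(**request_params, stream=stream)
        resp.raise_for_status()
        yield resp
        if not stream:
            # Stand-in for logging_obj.post_call; streaming responses skip it,
            # matching the `if not stream:` guard in the patch.
            print(f"post_call: status={resp.status_code}")

    # Callers decide how to consume the response, just as process_text_request
    # and process_stream_request do above (resp.json() vs. resp.iter_lines()).
    with managed_request({"method": "GET", "url": "https://httpbin.org/json"}) as resp:
        print(resp.json())

Keeping the pre/post logging inside one context manager is what lets the text-generation and embedding paths share the same instrumentation without the async/sync factory indirection this patch removes.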
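Finally, with watsonx dropped from the async fast paths in acompletion/aembedding, watsonx traffic (including stream=True) goes through the blocking requests-based code above; async callers presumably fall back to running it in an executor thread. A hypothetical smoke test, assuming watsonx credentials (URL, API key, project id) are already configured in the environment and using assumed model identifiers:

    import litellm

    # Streaming completion: served by process_stream_request + CustomStreamWrapper.
    response = litellm.completion(
        model="watsonx/ibm/granite-13b-chat-v2",  # assumed watsonx model id
        messages=[{"role": "user", "content": "Say hello."}],
        stream=True,
    )
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

    # Embedding: returns the {"object", "index", "embedding"} entries built above.
    emb = litellm.embedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",  # assumed embedding model id
        input=["hello world"],
    )
    print(len(emb.data))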