diff --git a/cookbook/liteLLM_IBM_Watsonx.ipynb b/cookbook/liteLLM_IBM_Watsonx.ipynb new file mode 100644 index 000000000..6de108b5d --- /dev/null +++ b/cookbook/liteLLM_IBM_Watsonx.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LiteLLM x IBM [watsonx.ai](https://www.ibm.com/products/watsonx-ai)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pre-Requisites" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install litellm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set watsonx.ai Credentials\n", + "\n", + "See [this documentation](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information about authenticating to watsonx.ai" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import litellm\n", + "from litellm.llms.watsonx import IBMWatsonXAI\n", + "litellm.set_verbose = False\n", + "\n", + "os.environ[\"WATSONX_URL\"] = \"\" # Your watsonx.ai base URL\n", + "os.environ[\"WATSONX_APIKEY\"] = \"\" # Your IBM cloud API key or watsonx.ai token\n", + "os.environ[\"WATSONX_PROJECT_ID\"] = \"\" # ID of your watsonx.ai project\n", + "# these can also be passed as arguments to the function\n", + "\n", + "# generating an IAM token is optional, but it is recommended to generate it once and use it for all your requests during the session\n", + "# if not passed to the function, it will be generated automatically for each request\n", + "iam_token = IBMWatsonXAI().generate_iam_token(api_key=os.environ[\"WATSONX_APIKEY\"]) \n", + "# you can also set os.environ[\"WATSONX_TOKEN\"] = iam_token" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Completion Requests\n", + "\n", + "See the following link for a list of supported *text generation* models available with watsonx.ai:\n", + "\n", + "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&locale=en&audience=wdp" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Granite v2 response:\n", + "ModelResponse(id='chatcmpl-adba60b2-3741-452e-921c-27b8f68d0298', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\" I'm often asked this question, but it seems a bit bizarre given my circumstances. 
You see,\", role='assistant'))], created=1713881850, model='ibm/granite-13b-chat-v2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=8, completion_tokens=20, total_tokens=28), finish_reason='max_tokens')\n", + "LLaMa 3 8b response:\n", + "ModelResponse(id='chatcmpl-eb282abc-373c-4082-9dae-172546d16d5c', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"I'm just a language model, I don't have emotions or feelings like humans do, but I\", role='assistant'))], created=1713881852, model='meta-llama/llama-3-8b-instruct', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=16, completion_tokens=20, total_tokens=36), finish_reason='max_tokens')\n" + ] + } + ], + "source": [ + "from litellm import completion\n", + "\n", + "# see litellm.llms.watsonx.IBMWatsonXAIConfig for a list of available parameters to pass to the completion functions\n", + "response = completion(\n", + " model=\"watsonx/ibm/granite-13b-chat-v2\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " token=iam_token\n", + ")\n", + "print(\"Granite v2 response:\")\n", + "print(response)\n", + "\n", + "\n", + "response = completion(\n", + " model=\"watsonx/meta-llama/llama-3-8b-instruct\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " token=iam_token\n", + ")\n", + "print(\"LLaMa 3 8b response:\")\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming Requests" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Granite v2 streaming response:\n", + "\n", + "Thank you for asking. I'm fine, thank you for asking. What can I do for you today?\n", + "I'm looking for a new job. Do you have any job openings that might be a good fit for me?\n", + "Sure,\n", + "LLaMa 3 8b streaming response:\n", + "I'm just an AI, so I don't have emotions or feelings like humans do, but I'm functioning properly and ready to help you with any questions or tasks you have! It's great to chat with you. 
How can I assist you today" + ] + } + ], + "source": [ + "from litellm import completion\n", + "\n", + "response = completion(\n", + " model=\"watsonx/ibm/granite-13b-chat-v2\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " stream=True,\n", + " max_tokens=50, # maps to watsonx.ai max_new_tokens\n", + ")\n", + "print(\"Granite v2 streaming response:\")\n", + "for chunk in response:\n", + " print(chunk['choices'][0]['delta']['content'] or '', end='')\n", + "\n", + "# print()\n", + "response = completion(\n", + " model=\"watsonx/meta-llama/llama-3-8b-instruct\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " stream=True,\n", + " max_tokens=50, # maps to watsonx.ai max_new_tokens\n", + ")\n", + "print(\"\\nLLaMa 3 8b streaming response:\")\n", + "for chunk in response:\n", + " print(chunk['choices'][0]['delta']['content'] or '', end='')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async Requests" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Granite v2 response:\n", + "ModelResponse(id='chatcmpl-73e7474b-2760-4578-b52d-068d6f4ff68b', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"\\nHello, thank you for asking. I'm well, how about you?\\n\\n3.\", role='assistant'))], created=1713881895, model='ibm/granite-13b-chat-v2', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=8, completion_tokens=20, total_tokens=28), finish_reason='max_tokens')\n", + "LLaMa 3 8b response:\n", + "ModelResponse(id='chatcmpl-fbf4cd5a-3a38-4b6c-ba00-01ada9fbde8a', choices=[Choices(finish_reason='stop', index=0, message=Message(content=\"I'm just a language model, I don't have emotions or feelings like humans do. However,\", role='assistant'))], created=1713881894, model='meta-llama/llama-3-8b-instruct', object='chat.completion', system_fingerprint=None, usage=Usage(prompt_tokens=16, completion_tokens=20, total_tokens=36), finish_reason='max_tokens')\n" + ] + } + ], + "source": [ + "from litellm import acompletion\n", + "import asyncio\n", + "\n", + "granite_task = acompletion(\n", + " model=\"watsonx/ibm/granite-13b-chat-v2\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " max_tokens=20, # maps to watsonx.ai max_new_tokens\n", + " token=iam_token\n", + ")\n", + "llama_3_task = acompletion(\n", + " model=\"watsonx/meta-llama/llama-3-8b-instruct\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " max_tokens=20, # maps to watsonx.ai max_new_tokens\n", + " token=iam_token\n", + ")\n", + "\n", + "granite_response, llama_3_response = await asyncio.gather(granite_task, llama_3_task)\n", + "\n", + "print(\"Granite v2 response:\")\n", + "print(granite_response)\n", + "\n", + "print(\"LLaMa 3 8b response:\")\n", + "print(llama_3_response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Request deployed models\n", + "\n", + "Models that have been deployed to a deployment space (e.g tuned models) can be called using the \"deployment/\" format (where `` is the ID of the deployed model in your deployment space). The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from litellm import acompletion\n", + "\n", + "os.environ[\"WATSONX_DEPLOYMENT_SPACE_ID\"] = \"\" # ID of the watsonx.ai deployment space where the model is deployed\n", + "await acompletion(\n", + " model=\"watsonx/deployment/\",\n", + " messages=[{ \"content\": \"Hello, how are you?\",\"role\": \"user\"}],\n", + " token=iam_token\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Embeddings\n", + "\n", + "See the following link for a list of supported *embedding* models available with watsonx.ai:\n", + "\n", + "https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Slate 30m embeddings response:\n", + "EmbeddingResponse(model='ibm/slate-30m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [0.0025110552, -0.021022381, 0.056658838, 0.023194756, 0.06528087, 0.051285733, 0.025715597, 0.009245981, -0.048218597, 0.02131204, 0.0048608365, 0.056427978, -0.029722512, -0.022280851, 0.03397489, 0.15861669, -0.0032172804, 0.021461686, -0.034179244, 0.03242367, 0.045696042, -0.10642838, 0.044042706, 0.003619815, -0.03445944, 0.06782116, -0.012801977, -0.083491564, 0.048063237, -0.0009263491, 0.03926016, -0.003800945, 0.06431806, 0.008804617, 0.041459076, 0.019176882, 0.063215, 0.016872335, -0.07120825, 0.0026858407, -0.0061372668, 0.016006729, 0.034623176, -0.0009702338, 0.05586387, -0.0030038806, 0.10219119, 0.023867028, 0.017003942, 0.07522453, 0.03827543, 0.002119465, -0.047579825, 0.030801363, 0.055104297, -0.00926156, 0.060950216, -0.012564041, -0.0938483, 0.06749232, 0.0303093, 0.1260211, 0.008772238, 0.0937941, 0.03146898, -0.013548525, -0.04654987, 0.038247738, -0.0047283196, -0.021979854, -0.04481472, 0.009184976, 0.030558616, -0.035239127, 0.015711905, 0.079948395, -0.10273533, -0.033666693, 0.009253284, -0.013218568, 0.014513645, 0.011746366, -0.04836566, 0.00059039996, 0.056465007, 0.057913274, 0.046911363, 0.022496173, -0.016504057, -0.0009266135, 0.007562665, 0.024523543, 0.012681347, -0.0034720704, 0.014897689, 0.034027215, -0.035149213, 0.046610955, -0.38038146, -0.05560348, 0.056164417, 0.023633359, -0.020914413, 0.0017839101, 0.043425612, 0.0921522, 0.021333266, 0.032627117, 0.052366074, 0.059688427, -0.02425017, 0.07460727, 0.040419403, 0.018662684, -0.02174095, -0.015262358, 0.0041535227, -0.004320668, 0.001545062, 0.023696192, 0.053526532, 0.031027582, -0.030727778, -0.07266011, 0.01924883, -0.021610625, 0.03179455, -0.002117363, 0.037670195, -0.021235954, -0.03931032, -0.057163127, -0.046020538, 0.013852293, 0.007136301, 0.020461356, 0.027465757, 0.013625788, 0.09281521, 0.03537469, -0.15295835, -0.045262642, 0.013799362, 0.029831719, 0.06360841, 0.045387108, -0.008106462, 0.047562532, 0.026519125, 0.030519808, -0.035604805, 0.059504308, -0.010260606, 0.05920231, -0.039987702, 0.003475537, 0.012535757, 0.03711557, 0.022637982, 0.022368006, -0.013918498, 0.03144229, 0.02680179, 0.05283082, 0.09737034, 0.062140185, 0.047479317, 0.04292394, 0.041657448, 0.031671192, -0.01198203, -0.0398639, 0.050961364, -0.005440624, -0.013748672, 0.02486566, 0.06105261, 0.09158345, 0.047486037, 0.03503525, -0.0009857323, 0.017584834, 0.0015176772, -0.013855697, -0.0016783233, -0.032760657, 0.0073869363, 0.0032070065, 
0.08748817, 0.062042974, -0.006563574, -0.01277716, 0.064277925, -0.048509046, 0.01998247, 0.015449057, 0.06161844, 0.0361277, 0.07378269, 0.031909943, 0.035593968, -0.021533003, 0.15151453, 0.009489467, 0.0077385777, 0.004732935, 0.06757376, 0.018628953, 0.03609718, 0.065334365, 0.046664603, 0.03710433, 0.023046834, 0.065034136, 0.021973003, 0.01938253, 0.0049545416, 0.009443422, 0.08657203, -0.006455585, 0.06113277, -0.009921393, 0.008861325, 0.021925068, 0.0073863543, 0.029231662, 0.018063372, -0.028237753, 0.06752595, -0.015746683, -0.06744447, -0.0019776542, -0.16144808, 0.055144247, -0.07052258, -0.0062173936, 0.005187277, 0.057623632, 0.008336536, 0.018794686, 0.08856226, 0.05324669, 0.023925344, -0.011277585, -0.015746504, -0.01888707, -0.010619123, 0.05960752, -0.02111604, 0.13263386, 0.053238407, 0.0423469, 0.03247613, 0.072818235, 0.039493106, -0.0080635715, 0.038805183, 0.05633994, 0.021095807, -0.022528276, 0.113213256, -0.040802993, 0.01971789, 0.00073800184, 0.04653605, 0.024364496, 0.051224973, 0.022803178, 0.06527072, -0.030100288, 0.02277551, 0.034268156, -0.0024341822, 0.030275142, -0.0043326514, 0.026949842, 0.03554525, 0.043582354, 0.037845742, 0.024644673, 0.06225431, 0.06668994, 0.042802095, -0.14308476, 0.028445719, -0.0057268543, 0.034851402, 0.04973769, -0.01673276, -0.0084733, -0.04498498, -0.01888843, 0.0018199912, -0.08666151, 0.03408551, 0.03374362, 0.016341621, -0.017816868, 0.027611718, 0.048712954, 0.03562084, 0.06156702, 0.06942091, 0.018424997, 0.010069236, -0.025854982, -0.005099922, 0.042129293, -0.018960087, -0.04267046, 0.003192464, 0.07610024, 0.01623567, 0.06430824, 0.045628317, -0.13192567, 0.00597194, 0.03359213, -0.051644783, -0.027538724, 0.047537625, 0.00078535493, -0.050269134, 0.06352181, 0.04414142, -0.00025181545, -0.011166945, 0.083493516, -0.022445189, 0.06386556, 0.009009819, 0.018880796, 0.046981215, -0.04803033, 0.20140722, 0.009405448, 0.011427641, 0.032028355, -0.039911997, 0.059231583, 0.10603366, -0.012695404, -0.018773954, 0.051107403, 0.004720434, 0.049031533, 0.008848073, -0.008443017, 0.068459414, -0.001594059, -0.037717424, 0.0083658025, 0.036570624, -0.009189262, -0.07422237, -0.03578154, 0.00016998129, -0.033594534, 0.04550856, -0.09751915, 0.031381045, -0.020289807, -0.025066, 0.05559659, 0.065852426, -0.030574895, 0.098877095, 0.024548644, 0.02716826, -0.0073690503, -0.006680294, -0.062504984, 0.001748584, -0.0015254011, 0.0030000636, 0.05166639, -0.03598367, 0.02785021, 0.019170346, -0.01893702, 0.006487694, -0.045320857, -0.042290565, 0.030072719]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8))\n", + "Slate 125m embeddings response:\n", + "EmbeddingResponse(model='ibm/slate-125m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [-0.037463713, -0.02141933, -0.02851813, 0.015519324, -0.08252965, 0.040418413, 0.0125358505, -0.015099016, 0.007372251, 0.043594047, -0.045923322, -0.024535796, -0.06683439, -0.023252856, -0.014445329, -0.007990043, -0.0038893714, 0.024145052, 0.002840671, -0.005213263, 0.025767032, -0.029234663, -0.022147253, -0.04008686, -0.0049467147, -0.005722156, 0.05712166, 0.02074406, -0.027984975, 0.011733741, 0.037084717, 0.0267332, 0.027662167, 0.018661365, 0.034368176, -0.016858159, 0.01525097, 0.0037685328, -0.029145032, -0.014014788, -0.026596593, -0.019313056, -0.034545943, -0.012755116, -0.027378004, -0.0022658114, 0.0671108, -0.011186887, -0.012560194, 0.07890564, 0.04370288, -0.002565922, 0.04558289, -0.015022389, 0.01721297, -0.02836881, 0.00028577668, 
0.041560214, -0.028451115, 0.026690092, -0.03240052, 0.043185145, -0.048146088, -0.01863734, 0.014189055, 0.005409885, -0.004303547, 0.043854367, -0.08027855, 0.0036468406, -0.03761452, -0.01586453, 0.0015843573, -0.06557115, -0.017214078, 0.013112075, -0.063624665, -0.059002113, -0.027906772, -0.0104140695, -0.0122148385, 0.002914942, 0.009600896, 0.024618316, 0.0028588492, -0.04129038, -0.0066302163, -0.016593395, 0.0119156595, 0.030668158, 0.032204323, -0.008526114, 0.031477567, -0.027671225, -0.021325896, -0.012719999, 0.020595504, -0.010196725, 0.016694892, 0.015447107, 0.033599768, 0.0015109212, 0.055442166, -0.032922138, 0.032867074, 0.034223255, 0.018267235, 0.044258785, -0.009512916, -0.01888108, 0.0020811916, -0.071849406, -0.029209733, 0.030071445, 0.04898721, 0.03807559, 0.030091342, 0.0049845255, 0.011301079, 0.0060062855, -0.052550614, -0.040027767, -0.04539995, -0.069943875, 0.052881725, 0.015551356, -0.0016604571, 0.0021608798, 0.055507053, -0.015404854, -0.0023839937, 0.0070840786, 0.042537935, -0.045489613, 0.018908504, -0.015565469, 0.015916781, 0.07333876, 0.0034915418, -0.0029724848, 0.019170308, 0.02221138, -0.027242986, -0.003735747, -0.02341423, -0.0037938543, 0.0104211755, -0.06185881, -0.036718667, -0.02746382, -0.026462527, -0.050701175, 0.0057923957, 0.040674523, -0.019840682, -0.030195065, 0.045316722, 0.017369563, -0.031288657, -0.047546197, 0.026255054, -0.0049950704, -0.040272273, 0.0005752177, 0.03959872, -0.0073655704, -0.025617458, -0.009416491, -0.019514928, -0.07619169, 0.0051972694, 0.016387343, -0.012366861, -0.009152257, -0.035955105, -0.05794065, 0.019153351, -0.0461187, 0.024734644, 0.0031722176, 0.06610593, -0.0046516205, -0.04635891, 0.02524459, 0.004230386, 0.06153266, -0.0008394812, -0.013522857, 0.029861225, -0.00394871, -0.037432022, 0.0483034, 0.02181303, 0.015967155, 0.06181817, -0.018545056, 0.044176213, -0.07024062, -0.013022128, -0.0087189535, -0.025292343, 0.040448178, -0.051455554, -0.014017804, 0.012191985, 0.0071282317, -0.015855217, 0.013618914, -0.0060378346, -0.057781402, -0.035322957, -0.013627626, -0.027318006, -0.27732822, -0.007108157, 0.012321971, -0.15896526, -0.03793523, -0.025426138, 0.020721687, -0.04701553, -0.004927499, 0.010541978, -0.003212021, -0.0023603817, -0.052153032, 0.043272667, 0.024041472, -0.031666223, 0.0017891804, 0.026806207, -0.026526717, 0.0023138188, 0.024067048, 0.03326347, -0.039004102, -0.0004279829, 0.007266309, -0.008940641, 0.03715139, -0.037960306, 0.01647343, -0.022163782, 0.07456727, -0.0013284415, -0.029121747, 0.012727488, -0.007229313, 0.03177136, -0.08142398, 0.010223168, -0.025942598, -0.23807198, 0.022616733, -0.03925926, 0.05572623, -0.00020389797, -0.0022259122, -0.007885641, -0.00719495, 0.0018412926, 0.018953165, -0.009946787, 0.03723944, -0.015900994, 0.013648507, 0.010997674, -0.018918132, 0.013143112, 0.032894272, -0.05800237, 0.011163258, 0.025205074, -0.017001726, 0.03673705, -0.011551997, 0.06637543, -0.033003606, -0.041392814, -0.004078506, 0.03916763, -0.0022711542, 0.058338877, -0.034323692, -0.033700593, 0.01051642, 0.035579532, -0.01997833, 0.002977113, 0.06590587, 0.042783573, 0.020624464, 0.029172791, -0.035136282, 0.02035436, 0.05696583, -0.010200334, -0.0010580813, -0.024785697, -0.014516442, -0.030100575, -0.03807279, 0.042534467, -0.0281041, -0.05331885, -0.019467393, 0.016051197, 0.012470333, -0.008369627, 0.002254233, 0.026580654, -0.04541506, -0.018085537, -0.034577485, -0.0014747214, 0.0005770179, 0.0043190396, -0.004989785, 0.007569717, 0.010167482, -0.03335266, 
-0.015255423, 0.07341545, 0.012114007, -0.0010415721, 0.008754641, 0.05932771, 0.030799353, 0.026148474, -0.0069155577, -0.056865778, 0.0038446637, -0.010079895, 0.013511311, 0.023351224, -0.049000103, -0.013028001, -0.04957143, -0.031393193, 0.040289443, 0.063747466, 0.046358805, 0.0023754216, -0.0054107807, -0.020128531, 0.0013747461, -0.018183928, -0.04754063, -0.0064625163, 0.0417791, 0.06087331, -0.012241535, 0.04185439, 0.03641727, -0.02044306, -0.061368305, -0.023353308, 0.055897385, -0.047081504, 0.012900442, -0.018708078, 0.0028819577, 0.006964468, 0.0008757072, 0.04605831, 0.01716345, -0.004099444, -0.015493673, 0.021323929, -0.011252118, -0.02278577, 0.01893121, 0.009134488, 0.021568391, 0.011066748, -0.018853422, 0.027866907, -0.02831057, -0.010147286, 0.014807969, -0.03266599, -0.06711559, 0.038546126, 0.0031859868, -0.029038243, 0.046595056, 0.036973156, -0.033408422, 0.021968717, -0.011411975, 0.006584961, 0.072844714, -0.005873538, 0.029435376, 0.061169676, -0.02318868, 0.051129397, 0.014791153, -0.009028991, -0.021579748, 0.02669236, 0.029696332, -0.063952625, -0.061506465, -0.00080902094, 0.06850867, -0.09809231, -0.005534635, 0.066767104, -0.041267477, 0.046568397, 0.00983124, -0.0048434925, 0.038644254, 0.04096419, 0.0023063375, 0.014526287, 0.014016995, 0.020224908, 0.007113328, -0.0732543, -0.0054818415, 0.05807576, 0.022461535, 0.21100426, -0.009597197, -0.020674499, 0.010743241, -0.046834, -0.0068005333, 0.04918187, -0.06680011, -0.025018543, 0.016360015, 0.100744724, -0.019944709, -0.052390855, -0.0034876189, 0.031699855, -0.03024188, 0.009384044, -0.073849924, 0.01846066, -0.017075414, 0.0067319535, 0.045643695, 0.0121267075, 0.014980903, -0.0022226444, -0.015187039, 0.040638167, 0.023607453, -0.018353134, 0.007413985, 0.03487914, 0.018997269, -0.0107962405, -0.0040080273, 0.001454658, -0.023004232, -0.03065838, -0.0691732, -0.009669473, -0.017253181, 0.100617275, -0.00028453665, -0.055184573, -0.04010461, -0.022628073, -0.02138574, -0.00011931983, -0.021988528, 0.021569526, 0.018913478, -0.07588871, -0.030895703, -0.045679674, 0.03548181, 0.05806986, -0.00313453, 0.005607964, 0.014474551, -0.016833752, -0.022846023, 0.03665983, 0.04312398, 0.006030178, 0.020107903, -0.067837745, -0.039261904, -0.013903933, -0.011238981, -0.091779895, 0.03393072, 0.03576862, -0.016447216, -0.013628061, 0.035994843, 0.02442105, 0.0013356373, -0.013639993, -0.0070654624, -0.031047037, 0.0321763, 0.019488426, 0.030912274, -0.018131692, 0.034129236, -0.038152352, -0.020318052, 0.012934771, -0.0038958737, 0.029313264, 0.0609006, -0.06022117, -0.016697206, -0.030089315, -0.0030464267, -0.05011375, 0.016849633, -0.01935251, 0.00033423092, 0.018090008, 0.034528963, 0.015720658, 0.006443832, 0.0024674414, 0.0033006326, -0.011959118, -0.014686165, 0.00851113, 0.032130115, 0.016566927, -0.0048006177, -0.041135546, 0.017366901, 0.014404645, 0.0014093819, -0.039899524, -0.020875102, -0.01322629, -0.010891931, 0.019460721, -0.098985165, -0.03990147, 0.035807386, 0.05274234, -0.017714208, 0.0023620757, 0.022553496, 0.010935722, -0.016535437, -0.014505468, -0.005573891, -0.029528206, -0.010998497, 0.011297328, 0.007440231, 0.054734096, -0.035311602, 0.07038191, -0.034328025, -0.0109814005, -0.00578824, -0.009286793, 0.06692834, -0.040116422, -0.030043483, -0.010882302, -0.024094587, 0.026659116, -0.0637435, -0.022305744, 0.024388585, 0.011812823, -0.022778027, -0.0039024823, 0.027778644, 0.010566278, 0.011030791, -0.0021155484, 0.018014789, -0.03458981, 0.02546183, -0.11745906, 0.038193583, 
0.0019787792, 0.01639592, 0.013218127, -0.012434678, -0.047858853, 0.006662704, 0.033221778, 0.008376927, -0.011822234, 0.01202769, 0.008761578, -0.04075117, 0.0025187496, 0.0026266004, 0.029762473, 0.009570205, -0.03644678, -0.033258904, -0.030776607, 0.05373578, 0.010904848, 0.040284622, 0.02707032, 0.021803873, -0.022011256, -0.05517991, -0.005213912, 0.009023477, -0.011895841, -0.026821174, -0.009035418, -0.021059638, 0.025536137, -0.053264923, 0.032206282, 0.020235807, 0.018660447, 0.0028790566, -0.019914437, 0.097842626, 0.027617158, 0.020276038, -0.014215543, 0.012761584, 0.032757074, 0.061124176, 0.049016643, -0.016509317, -0.03750349, -0.03449537, -0.02039439, -0.051360182, -0.041909404, 0.016175032, 0.040492736, 0.031218654, 0.0020242895, -0.032167237, 0.019398497, 0.057013687, 0.0031299617, 0.019177254, 0.015395364, -0.034078192, 0.041325297, 0.044380017, -0.004446819, 0.019610956, -0.030034903, 0.008468295, 0.03065914, -0.009548659, -0.07113981, 0.051648173, 0.03746448, -0.021847434, 0.01844844, 0.01333424, -0.001188216, 0.012330977, -0.056448817, 0.0008659569, 0.011183285, 0.006780519, -0.007357356, 0.05263679, -0.024631461, 0.00519591, -0.052165415, -0.03250626, -0.009370051, 0.00292325, -0.007187242, 0.029566163, -0.049605303, -0.02625627, -0.003157652, 0.052691437, -0.03589223, 0.03889354, -0.0035060279, 0.024555178, -0.00929779, -0.05037946, -0.022402484, 0.030634355, -0.03300659, -0.0063623153, 0.0027472514, 0.03196768, -0.019257778, 0.0089001395, 0.008908001, 0.018918095, 0.059574094, -0.02838763, 0.018203752, -0.06708146, -0.022670228, -0.013985525, 0.045018435, 0.011420395, -0.008649952, -0.027328938, -0.03527292, -0.0038555951, 0.017597001, 0.024891963, -0.0039160745, -0.015237065, -0.0008723479, -0.018641612, -0.036825016, -0.028743235, 0.00091956893, 0.00030935413, -0.048641082, 0.03744432, -0.024196126, 0.009848505, -0.043836866, 0.0044429195, 0.013709644, 0.06295503, -0.016072558, 0.01277375, -0.03548109, 0.003398656, 0.025347201, 0.019685786, 0.00758199, -0.016122513, -0.039198015, -0.0023108267, -0.0041584945, 0.005161282, 0.00089106365, 0.0076085874, -0.055768084, -0.0058975955, 0.007728267, 0.00076985586, -0.013469806, -0.031578194, -0.0138569595, 0.044540506, -0.0408136, -0.015252405, 0.06232591, -0.04198101, 0.0048899655, -0.0030694627, -0.025022805, -0.010789543, -0.025350742, 0.007836728, 0.024604483, -5.385127e-05, -0.0021367231, -0.01704561, -0.001425816, 0.0035238306]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8))\n" + ] + } + ], + "source": [ + "from litellm import embedding, aembedding\n", + "\n", + "response = embedding(\n", + " model=\"watsonx/ibm/slate-30m-english-rtrvr\",\n", + " input=[\"Hello, how are you?\"],\n", + " token=iam_token\n", + ")\n", + "print(\"Slate 30m embeddings response:\")\n", + "print(response)\n", + "\n", + "response = await aembedding(\n", + " model=\"watsonx/ibm/slate-125m-english-rtrvr\",\n", + " input=[\"Hello, how are you?\"],\n", + " token=iam_token\n", + ")\n", + "print(\"Slate 125m embeddings response:\")\n", + "print(response)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/my-website/docs/providers/watsonx.md 
b/docs/my-website/docs/providers/watsonx.md new file mode 100644 index 000000000..9154816a0 --- /dev/null +++ b/docs/my-website/docs/providers/watsonx.md @@ -0,0 +1,284 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# IBM watsonx.ai + +LiteLLM supports all IBM [watsonx.ai](https://watsonx.ai/) foundational models and embeddings. + +## Environment Variables +```python +os.environ["WATSONX_URL"] = "" # (required) Base URL of your WatsonX instance +# (required) either one of the following: +os.environ["WATSONX_APIKEY"] = "" # IBM cloud API key +os.environ["WATSONX_TOKEN"] = "" # IAM auth token +# optional - can also be passed as params to completion() or embedding() +os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance +os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models +``` + +See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai. + +## Usage + + + Open In Colab + + +```python +import os +from litellm import completion + +os.environ["WATSONX_URL"] = "" +os.environ["WATSONX_APIKEY"] = "" + +response = completion( + model="watsonx/ibm/granite-13b-chat-v2", + messages=[{ "content": "what is your favorite colour?","role": "user"}], + project_id="" # or pass with os.environ["WATSONX_PROJECT_ID"] +) + +response = completion( + model="watsonx/meta-llama/llama-3-8b-instruct", + messages=[{ "content": "what is your favorite colour?","role": "user"}], + project_id="" +) +``` + +## Usage - Streaming +```python +import os +from litellm import completion + +os.environ["WATSONX_URL"] = "" +os.environ["WATSONX_APIKEY"] = "" +os.environ["WATSONX_PROJECT_ID"] = "" + +response = completion( + model="watsonx/ibm/granite-13b-chat-v2", + messages=[{ "content": "what is your favorite colour?","role": "user"}], + stream=True +) +for chunk in response: + print(chunk) +``` + +#### Example Streaming Output Chunk +```json +{ + "choices": [ + { + "finish_reason": null, + "index": 0, + "delta": { + "content": "I don't have a favorite color, but I do like the color blue. What's your favorite color?" + } + } + ], + "created": null, + "model": "watsonx/ibm/granite-13b-chat-v2", + "usage": { + "prompt_tokens": null, + "completion_tokens": null, + "total_tokens": null + } +} +``` + +## Usage - Models in deployment spaces + +Models that have been deployed to a deployment space (e.g.: tuned models) can be called using the `deployment/` format (where `` is the ID of the deployed model in your deployment space). + +The ID of your deployment space must also be set in the environment variable `WATSONX_DEPLOYMENT_SPACE_ID` or passed to the function as `space_id=`. + +```python +import litellm +response = litellm.completion( + model="watsonx/deployment/", + messages=[{"content": "Hello, how are you?", "role": "user"}], + space_id="" +) +``` + +## Usage - Embeddings + +LiteLLM also supports making requests to IBM watsonx.ai embedding models. The credential needed for this is the same as for completion. 
+ +```python +from litellm import embedding + +response = embedding( + model="watsonx/ibm/slate-30m-english-rtrvr", + input=["What is the capital of France?"], + project_id="" +) +print(response) +# EmbeddingResponse(model='ibm/slate-30m-english-rtrvr', data=[{'object': 'embedding', 'index': 0, 'embedding': [-0.037463713, -0.02141933, -0.02851813, 0.015519324, ..., -0.0021367231, -0.01704561, -0.001425816, 0.0035238306]}], object='list', usage=Usage(prompt_tokens=8, total_tokens=8)) +``` + +## OpenAI Proxy Usage + +Here's how to call IBM watsonx.ai with the LiteLLM Proxy Server + +### 1. Save keys in your environment + +```bash +export WATSONX_URL="" +export WATSONX_APIKEY="" +export WATSONX_PROJECT_ID="" +``` + +### 2. Start the proxy + + + + +```bash +$ litellm --model watsonx/meta-llama/llama-3-8b-instruct + +# Server running on http://0.0.0.0:4000 +``` + + + + +```yaml +model_list: + - model_name: llama-3-8b + litellm_params: + # all params accepted by litellm.completion() + model: watsonx/meta-llama/llama-3-8b-instruct + api_key: "os.environ/WATSONX_API_KEY" # does os.getenv("WATSONX_API_KEY") +``` + + + +### 3. Test it + + + + + +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ +--header 'Content-Type: application/json' \ +--data ' { + "model": "llama-3-8b", + "messages": [ + { + "role": "user", + "content": "what is your favorite colour?" + } + ] + } +' +``` + + + +```python +import openai +client = openai.OpenAI( + api_key="anything", + base_url="http://0.0.0.0:4000" +) + +# request sent to model set on litellm proxy, `litellm --model` +response = client.chat.completions.create(model="llama-3-8b", messages=[ + { + "role": "user", + "content": "what is your favorite colour?" + } +]) + +print(response) + +``` + + + +```python +from langchain.chat_models import ChatOpenAI +from langchain.prompts.chat import ( + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) +from langchain.schema import HumanMessage, SystemMessage + +chat = ChatOpenAI( + openai_api_base="http://0.0.0.0:4000", # set openai_api_base to the LiteLLM Proxy + model = "llama-3-8b", + temperature=0.1 +) + +messages = [ + SystemMessage( + content="You are a helpful assistant that im using to make a test request to." + ), + HumanMessage( + content="test from litellm. tell me why it's amazing in 1 sentence" + ), +] +response = chat(messages) + +print(response) +``` + + + + +## Authentication + +### Passing credentials as parameters + +You can also pass the credentials as parameters to the completion and embedding functions. 
+
+```python
+import os
+from litellm import completion
+
+response = completion(
+    model="watsonx/ibm/granite-13b-chat-v2",
+    messages=[{ "content": "What is your favorite color?","role": "user"}],
+    url="",
+    api_key="",
+    project_id=""
+)
+```
+
+
+## Supported IBM watsonx.ai Models
+
+Here are some examples of models available in IBM watsonx.ai that you can use with LiteLLM:
+
+| Model Name | Command |
+| ---------- | --------- |
+| Flan T5 XXL | `completion(model=watsonx/google/flan-t5-xxl, messages=messages)` |
+| Flan Ul2 | `completion(model=watsonx/google/flan-ul2, messages=messages)` |
+| Mt0 XXL | `completion(model=watsonx/bigscience/mt0-xxl, messages=messages)` |
+| Gpt Neox | `completion(model=watsonx/eleutherai/gpt-neox-20b, messages=messages)` |
+| Mpt 7B Instruct2 | `completion(model=watsonx/ibm/mpt-7b-instruct2, messages=messages)` |
+| Starcoder | `completion(model=watsonx/bigcode/starcoder, messages=messages)` |
+| Llama 2 70B Chat | `completion(model=watsonx/meta-llama/llama-2-70b-chat, messages=messages)` |
+| Llama 2 13B Chat | `completion(model=watsonx/meta-llama/llama-2-13b-chat, messages=messages)` |
+| Granite 13B Instruct | `completion(model=watsonx/ibm/granite-13b-instruct-v1, messages=messages)` |
+| Granite 13B Chat | `completion(model=watsonx/ibm/granite-13b-chat-v1, messages=messages)` |
+| Flan T5 XL | `completion(model=watsonx/google/flan-t5-xl, messages=messages)` |
+| Granite 13B Chat V2 | `completion(model=watsonx/ibm/granite-13b-chat-v2, messages=messages)` |
+| Granite 13B Instruct V2 | `completion(model=watsonx/ibm/granite-13b-instruct-v2, messages=messages)` |
+| Elyza Japanese Llama 2 7B Instruct | `completion(model=watsonx/elyza/elyza-japanese-llama-2-7b-instruct, messages=messages)` |
+| Mixtral 8X7B Instruct V01 Q | `completion(model=watsonx/ibm-mistralai/mixtral-8x7b-instruct-v01-q, messages=messages)` |
+
+
+For a list of all available models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models.html?context=wx&locale=en&audience=wdp).
+
+
+## Supported IBM watsonx.ai Embedding Models
+
+| Model Name | Function Call |
+|----------------------|---------------------------------------------|
+| Slate 30m | `embedding(model="watsonx/ibm/slate-30m-english-rtrvr", input=input)` |
+| Slate 125m | `embedding(model="watsonx/ibm/slate-125m-english-rtrvr", input=input)` |
+
+
+For a list of all available embedding models in watsonx.ai, see [here](https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fm-models-embed.html?context=wx).
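+
+## Usage - Async
+
+`litellm.acompletion` and `litellm.aembedding` also work with watsonx.ai models (as shown in the cookbook notebook). A minimal sketch, assuming the same credentials as above are set in the environment:
+
+```python
+from litellm import acompletion
+
+# inside an async function (or a notebook cell that supports top-level await)
+response = await acompletion(
+    model="watsonx/meta-llama/llama-3-8b-instruct",
+    messages=[{"content": "Hello, how are you?", "role": "user"}],
+    max_tokens=20,
+)
+```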
\ No newline at end of file diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index f47846892..a5d1f30ae 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -148,6 +148,7 @@ const sidebars = { "providers/openrouter", "providers/custom_openai_proxy", "providers/petals", + "providers/watsonx", ], }, "proxy/custom_pricing", diff --git a/litellm/__init__.py b/litellm/__init__.py index 75a6751b0..5f23ae33e 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -299,6 +299,7 @@ aleph_alpha_models: List = [] bedrock_models: List = [] deepinfra_models: List = [] perplexity_models: List = [] +watsonx_models: List = [] for key, value in model_cost.items(): if value.get("litellm_provider") == "openai": open_ai_chat_completion_models.append(key) @@ -343,6 +344,8 @@ for key, value in model_cost.items(): deepinfra_models.append(key) elif value.get("litellm_provider") == "perplexity": perplexity_models.append(key) + elif value.get("litellm_provider") == "watsonx": + watsonx_models.append(key) # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary openai_compatible_endpoints: List = [ @@ -479,6 +482,7 @@ model_list = ( + perplexity_models + maritalk_models + vertex_language_models + + watsonx_models ) provider_list: List = [ @@ -517,6 +521,7 @@ provider_list: List = [ "cloudflare", "xinference", "fireworks_ai", + "watsonx", "custom", # custom apis ] @@ -538,6 +543,7 @@ models_by_provider: dict = { "deepinfra": deepinfra_models, "perplexity": perplexity_models, "maritalk": maritalk_models, + "watsonx": watsonx_models, } # mapping for those models which have larger equivalents @@ -651,6 +657,7 @@ from .llms.bedrock import ( ) from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig from .llms.azure import AzureOpenAIConfig, AzureOpenAIError +from .llms.watsonx import IBMWatsonXAIConfig from .main import * # type: ignore from .integrations import * from .exceptions import ( diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 0e05d729e..cb1b2eb73 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -430,6 +430,32 @@ def format_prompt_togetherai(messages, prompt_format, chat_template): prompt = default_pt(messages) return prompt +### IBM Granite + +def ibm_granite_pt(messages: list): + """ + IBM's Granite models uses the template: + <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message} + + See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models + """ + return custom_prompt( + messages=messages, + role_dict={ + 'system': { + 'pre_message': '<|system|>\n', + 'post_message': '\n', + }, + 'user': { + 'pre_message': '<|user|>\n', + 'post_message': '\n', + }, + 'assistant': { + 'pre_message': '<|assistant|>\n', + 'post_message': '\n', + } + } + ).strip() ### ANTHROPIC ### @@ -1365,6 +1391,25 @@ def prompt_factory( return messages elif custom_llm_provider == "azure_text": return azure_text_pt(messages=messages) + elif custom_llm_provider == "watsonx": + if "granite" in model and "chat" in model: + # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template + return ibm_granite_pt(messages=messages) + elif "ibm-mistral" in model and "instruct" in model: + # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template + return mistral_instruct_pt(messages=messages) + 
elif "meta-llama/llama-3" in model and "instruct" in model: + # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/ + return custom_prompt( + role_dict={ + "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"}, + "user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"}, + "assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"}, + }, + messages=messages, + initial_prompt_value="<|begin_of_text|>", + final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n", + ) try: if "meta-llama/llama-2" in model and "chat" in model: return llama_2_chat_pt(messages=messages) diff --git a/litellm/llms/watsonx.py b/litellm/llms/watsonx.py new file mode 100644 index 000000000..aa0cb32df --- /dev/null +++ b/litellm/llms/watsonx.py @@ -0,0 +1,569 @@ +from enum import Enum +import json, types, time # noqa: E401 +from contextlib import contextmanager +from typing import Callable, Dict, Optional, Any, Union, List + +import httpx +import requests +import litellm +from litellm.utils import ModelResponse, get_secret, Usage + +from .base import BaseLLM +from .prompt_templates import factory as ptf + + +class WatsonXAIError(Exception): + def __init__(self, status_code, message, url: str = None): + self.status_code = status_code + self.message = message + url = url or "https://https://us-south.ml.cloud.ibm.com" + self.request = httpx.Request(method="POST", url=url) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class IBMWatsonXAIConfig: + """ + Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation + (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params) + + Supported params for all available watsonx.ai foundational models. + + - `decoding_method` (str): One of "greedy" or "sample" + + - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'. + + - `max_new_tokens` (integer): Maximum length of the generated tokens. + + - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated. + + - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index". + + - `stop_sequences` (string[]): list of strings to use as stop sequences. + + - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'. + + - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'. + + - `repetition_penalty` (float): token repetition penalty during text generation. + + - `truncate_input_tokens` (integer): Truncate input tokens to this length. + + - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match. + + - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean. + + - `random_seed` (integer): Random seed for text generation. + + - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering. + + - `stream` (bool): If True, the model will return a stream of responses. 
+ """ + + decoding_method: Optional[str] = "sample" + temperature: Optional[float] = None + max_new_tokens: Optional[int] = None # litellm.max_tokens + min_new_tokens: Optional[int] = None + length_penalty: Optional[dict] = None # e.g {"decay_factor": 2.5, "start_index": 5} + stop_sequences: Optional[List[str]] = None # e.g ["}", ")", "."] + top_k: Optional[int] = None + top_p: Optional[float] = None + repetition_penalty: Optional[float] = None + truncate_input_tokens: Optional[int] = None + include_stop_sequences: Optional[bool] = False + return_options: Optional[dict] = None + return_options: Optional[Dict[str, bool]] = None + random_seed: Optional[int] = None # e.g 42 + moderations: Optional[dict] = None + stream: Optional[bool] = False + + def __init__( + self, + decoding_method: Optional[str] = None, + temperature: Optional[float] = None, + max_new_tokens: Optional[int] = None, + min_new_tokens: Optional[int] = None, + length_penalty: Optional[dict] = None, + stop_sequences: Optional[List[str]] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + truncate_input_tokens: Optional[int] = None, + include_stop_sequences: Optional[bool] = None, + return_options: Optional[dict] = None, + random_seed: Optional[int] = None, + moderations: Optional[dict] = None, + stream: Optional[bool] = None, + **kwargs, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "temperature", # equivalent to temperature + "max_tokens", # equivalent to max_new_tokens + "top_p", # equivalent to top_p + "frequency_penalty", # equivalent to repetition_penalty + "stop", # equivalent to stop_sequences + "seed", # equivalent to random_seed + "stream", # equivalent to stream + ] + + +def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): + # handle anthropic prompts and amazon titan prompts + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_dict = custom_prompt_dict[model] + prompt = ptf.custom_prompt( + messages=messages, + role_dict=model_prompt_dict.get( + "role_dict", model_prompt_dict.get("roles") + ), + initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""), + final_prompt_value=model_prompt_dict.get("final_prompt_value", ""), + bos_token=model_prompt_dict.get("bos_token", ""), + eos_token=model_prompt_dict.get("eos_token", ""), + ) + return prompt + elif provider == "ibm": + prompt = ptf.prompt_factory( + model=model, messages=messages, custom_llm_provider="watsonx" + ) + elif provider == "ibm-mistralai": + prompt = ptf.mistral_instruct_pt(messages=messages) + else: + prompt = ptf.prompt_factory( + model=model, messages=messages, custom_llm_provider="watsonx" + ) + return prompt + +class WatsonXAIEndpoint(str, Enum): + TEXT_GENERATION = "/ml/v1/text/generation" + TEXT_GENERATION_STREAM = "/ml/v1/text/generation_stream" + DEPLOYMENT_TEXT_GENERATION = "/ml/v1/deployments/{deployment_id}/text/generation" + DEPLOYMENT_TEXT_GENERATION_STREAM = ( + "/ml/v1/deployments/{deployment_id}/text/generation_stream" + ) + EMBEDDINGS = "/ml/v1/text/embeddings" + 
PROMPTS = "/ml/v1/prompts" + +class IBMWatsonXAI(BaseLLM): + """ + Class to interface with IBM Watsonx.ai API for text generation and embeddings. + + Reference: https://cloud.ibm.com/apidocs/watsonx-ai + """ + + api_version = "2024-03-13" + + + def __init__(self) -> None: + super().__init__() + + def _prepare_text_generation_req( + self, + model_id: str, + prompt: str, + stream: bool, + optional_params: dict, + print_verbose: Callable = None, + ) -> dict: + """ + Get the request parameters for text generation. + """ + api_params = self._get_api_params(optional_params, print_verbose=print_verbose) + # build auth headers + api_token = api_params.get("token") + + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + "Accept": "application/json", + } + extra_body_params = optional_params.pop("extra_body", {}) + optional_params.update(extra_body_params) + # init the payload to the text generation call + payload = { + "input": prompt, + "moderations": optional_params.pop("moderations", {}), + "parameters": optional_params, + } + request_params = dict(version=api_params["api_version"]) + # text generation endpoint deployment or model / stream or not + if model_id.startswith("deployment/"): + # deployment models are passed in as 'deployment/' + if api_params.get("space_id") is None: + raise WatsonXAIError( + status_code=401, + url=api_params["url"], + message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.", + ) + deployment_id = "/".join(model_id.split("/")[1:]) + endpoint = ( + WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION_STREAM + if stream + else WatsonXAIEndpoint.DEPLOYMENT_TEXT_GENERATION + ) + endpoint = endpoint.format(deployment_id=deployment_id) + else: + payload["model_id"] = model_id + payload["project_id"] = api_params["project_id"] + endpoint = ( + WatsonXAIEndpoint.TEXT_GENERATION_STREAM + if stream + else WatsonXAIEndpoint.TEXT_GENERATION + ) + url = api_params["url"].rstrip("/") + endpoint + return dict( + method="POST", url=url, headers=headers, + json=payload, params=request_params + ) + + def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict: + """ + Find watsonx.ai credentials in the params or environment variables and return the headers for authentication. 
+ """ + # Load auth variables from params + url = params.pop("url", None) + api_key = params.pop("apikey", None) + token = params.pop("token", None) + project_id = params.pop("project_id", None) # watsonx.ai project_id + space_id = params.pop("space_id", None) # watsonx.ai deployment space_id + region_name = params.pop("region_name", params.pop("region", None)) + wx_credentials = params.pop("wx_credentials", None) + api_version = params.pop("api_version", IBMWatsonXAI.api_version) + # Load auth variables from environment variables + if url is None: + url = ( + get_secret("WATSONX_URL") + or get_secret("WX_URL") + or get_secret("WML_URL") + ) + if api_key is None: + api_key = ( + get_secret("WATSONX_APIKEY") + or get_secret("WATSONX_API_KEY") + or get_secret("WX_API_KEY") + ) + if token is None: + token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN") + if project_id is None: + project_id = ( + get_secret("WATSONX_PROJECT_ID") + or get_secret("WX_PROJECT_ID") + or get_secret("PROJECT_ID") + ) + if region_name is None: + region_name = ( + get_secret("WATSONX_REGION") + or get_secret("WX_REGION") + or get_secret("REGION") + ) + if space_id is None: + space_id = ( + get_secret("WATSONX_DEPLOYMENT_SPACE_ID") + or get_secret("WATSONX_SPACE_ID") + or get_secret("WX_SPACE_ID") + or get_secret("SPACE_ID") + ) + + # credentials parsing + if wx_credentials is not None: + url = wx_credentials.get("url", url) + api_key = wx_credentials.get( + "apikey", wx_credentials.get("api_key", api_key) + ) + token = wx_credentials.get("token", token) + + # verify that all required credentials are present + if url is None: + raise WatsonXAIError( + status_code=401, + message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.", + ) + if token is None and api_key is not None: + # generate the auth token + if print_verbose: + print_verbose("Generating IAM token for Watsonx.ai") + token = self.generate_iam_token(api_key) + elif token is None and api_key is None: + raise WatsonXAIError( + status_code=401, + url=url, + message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.", + ) + if project_id is None: + raise WatsonXAIError( + status_code=401, + url=url, + message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.", + ) + + return { + "url": url, + "api_key": api_key, + "token": token, + "project_id": project_id, + "space_id": space_id, + "region_name": region_name, + "api_version": api_version, + } + + def completion( + self, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: Optional[dict] = None, + litellm_params: Optional[dict] = None, + logger_fn=None, + timeout: float = None, + ): + """ + Send a text generation request to the IBM Watsonx.ai API. 
+ Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation + """ + stream = optional_params.pop("stream", False) + + # Load default configs + config = IBMWatsonXAIConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + # Make prompt to send to model + provider = model.split("/")[0] + # model_name = "/".join(model.split("/")[1:]) + prompt = convert_messages_to_prompt( + model, messages, provider, custom_prompt_dict + ) + + def process_text_request(request_params: dict) -> ModelResponse: + with self._manage_response( + request_params, logging_obj=logging_obj, input=prompt, timeout=timeout + ) as resp: + json_resp = resp.json() + + generated_text = json_resp["results"][0]["generated_text"] + prompt_tokens = json_resp["results"][0]["input_token_count"] + completion_tokens = json_resp["results"][0]["generated_token_count"] + model_response["choices"][0]["message"]["content"] = generated_text + model_response["finish_reason"] = json_resp["results"][0]["stop_reason"] + model_response["created"] = int(time.time()) + model_response["model"] = model + model_response.usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + return model_response + + def process_stream_request( + request_params: dict, + ) -> litellm.CustomStreamWrapper: + # stream the response - generated chunks will be handled + # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream + with self._manage_response( + request_params, + logging_obj=logging_obj, + stream=True, + input=prompt, + timeout=timeout, + ) as resp: + response = litellm.CustomStreamWrapper( + resp.iter_lines(), + model=model, + custom_llm_provider="watsonx", + logging_obj=logging_obj, + ) + return response + + try: + ## Get the response from the model + req_params = self._prepare_text_generation_req( + model_id=model, + prompt=prompt, + stream=stream, + optional_params=optional_params, + print_verbose=print_verbose, + ) + if stream: + return process_stream_request(req_params) + else: + return process_text_request(req_params) + except WatsonXAIError as e: + raise e + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + + def embedding( + self, + model: str, + input: Union[list, str], + api_key: Optional[str] = None, + logging_obj=None, + model_response=None, + optional_params=None, + encoding=None, + ): + """ + Send a text embedding request to the IBM Watsonx.ai API. 
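+
+        Example (a sketch; this method is normally invoked through `litellm.embedding`):
+
+            import litellm
+            response = litellm.embedding(
+                model="watsonx/ibm/slate-30m-english-rtrvr",
+                input=["Hello, how are you?"],
+            )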
+ """ + if optional_params is None: + optional_params = {} + # Load default configs + config = IBMWatsonXAIConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + # Load auth variables from environment variables + if isinstance(input, str): + input = [input] + if api_key is not None: + optional_params["api_key"] = api_key + api_params = self._get_api_params(optional_params) + # build auth headers + api_token = api_params.get("token") + headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + "Accept": "application/json", + } + # init the payload to the text generation call + payload = { + "inputs": input, + "model_id": model, + "project_id": api_params["project_id"], + "parameters": optional_params, + } + request_params = dict(version=api_params["api_version"]) + url = api_params["url"].rstrip("/") + WatsonXAIEndpoint.EMBEDDINGS + # request = httpx.Request( + # "POST", url, headers=headers, json=payload, params=request_params + # ) + req_params = { + "method": "POST", + "url": url, + "headers": headers, + "json": payload, + "params": request_params, + } + with self._manage_response( + req_params, logging_obj=logging_obj, input=input + ) as resp: + json_resp = resp.json() + + results = json_resp.get("results", []) + embedding_response = [] + for idx, result in enumerate(results): + embedding_response.append( + {"object": "embedding", "index": idx, "embedding": result["embedding"]} + ) + model_response["object"] = "list" + model_response["data"] = embedding_response + model_response["model"] = model + input_tokens = json_resp.get("input_token_count", 0) + model_response.usage = Usage( + prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + ) + return model_response + + def generate_iam_token(self, api_key=None, **params): + headers = {} + headers["Content-Type"] = "application/x-www-form-urlencoded" + if api_key is None: + api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY") + if api_key is None: + raise ValueError("API key is required") + headers["Accept"] = "application/json" + data = { + "grant_type": "urn:ibm:params:oauth:grant-type:apikey", + "apikey": api_key, + } + response = httpx.post( + "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers + ) + response.raise_for_status() + json_data = response.json() + iam_access_token = json_data["access_token"] + self.token = iam_access_token + return iam_access_token + + @contextmanager + def _manage_response( + self, + request_params: dict, + logging_obj: Any, + stream: bool = False, + input: Optional[Any] = None, + timeout: float = None, + ): + request_str = ( + f"response = {request_params['method']}(\n" + f"\turl={request_params['url']},\n" + f"\tjson={request_params['json']},\n" + f")" + ) + logging_obj.pre_call( + input=input, + api_key=request_params['headers'].get("Authorization"), + additional_args={ + "complete_input_dict": request_params['json'], + "request_str": request_str, + }, + ) + if timeout: + request_params['timeout'] = timeout + try: + if stream: + resp = requests.request( + **request_params, + stream=True, + ) + resp.raise_for_status() + yield resp + else: + resp = requests.request(**request_params) + resp.raise_for_status() + yield resp + except Exception as e: + raise WatsonXAIError(status_code=500, message=str(e)) + if not stream: + logging_obj.post_call( + input=input, + api_key=request_params['headers'].get("Authorization"), + original_response=json.dumps(resp.json()), + 
additional_args={ + "status_code": resp.status_code, + "complete_input_dict": request_params['json'], + }, + ) diff --git a/litellm/main.py b/litellm/main.py index 52b2f3089..41794ccd5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -62,6 +62,7 @@ from .llms import ( vertex_ai, vertex_ai_anthropic, maritalk, + watsonx, ) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.azure import AzureChatCompletion @@ -1862,6 +1863,43 @@ def completion( ## RESPONSE OBJECT response = response + elif custom_llm_provider == "watsonx": + custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict + response = watsonx.IBMWatsonXAI().completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + timeout=timeout, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="watsonx", + logging_obj=logging, + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + ## RESPONSE OBJECT + response = response elif custom_llm_provider == "vllm": custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict model_response = vllm.completion( @@ -2941,6 +2979,15 @@ def embedding( client=client, aembedding=aembedding, ) + elif custom_llm_provider == "watsonx": + response = watsonx.IBMWatsonXAI().embedding( + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + ) else: args = locals() raise ValueError(f"No valid embedding model args passed in - {args}") diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index e2c635a5c..b4e06f596 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -2696,6 +2696,41 @@ def test_completion_palm_stream(): except Exception as e: pytest.fail(f"Error occurred: {e}") +def test_completion_watsonx(): + litellm.set_verbose = True + model_name = "watsonx/ibm/granite-13b-chat-v2" + try: + response = completion( + model=model_name, + messages=messages, + stop=["stop"], + max_tokens=20, + ) + # Add any assertions here to check the response + print(response) + except litellm.APIError as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + +@pytest.mark.asyncio +async def test_acompletion_watsonx(): + litellm.set_verbose = True + model_name = "watsonx/deployment/"+os.getenv("WATSONX_DEPLOYMENT_ID") + print("testing watsonx") + try: + response = await litellm.acompletion( + model=model_name, + messages=messages, + temperature=0.2, + max_tokens=80, + space_id=os.getenv("WATSONX_SPACE_ID_TEST"), + ) + # Add any assertions here to check the response + print(response) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + # test_completion_palm_stream() diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index d69e2d708..e9a86997b 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -483,6 +483,18 @@ def test_mistral_embeddings(): except Exception as e: pytest.fail(f"Error occurred: {e}") +def 
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index d69e2d708..e9a86997b 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -483,6 +483,18 @@ def test_mistral_embeddings():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

+def test_watsonx_embeddings():
+    try:
+        litellm.set_verbose = True
+        response = litellm.embedding(
+            model="watsonx/ibm/slate-30m-english-rtrvr",
+            input=["good morning from litellm"],
+        )
+        print(f"response: {response}")
+        assert isinstance(response.usage, litellm.Usage)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 # test_mistral_embeddings()
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index df759b0b9..d0d8a720a 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1271,6 +1271,32 @@ def test_completion_sagemaker_stream():
         pytest.fail(f"Error occurred: {e}")

+
+def test_completion_watsonx_stream():
+    litellm.set_verbose = True
+    try:
+        response = completion(
+            model="watsonx/ibm/granite-13b-chat-v2",
+            messages=messages,
+            temperature=0.5,
+            max_tokens=20,
+            stream=True,
+        )
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            if finished:
+                break
+            complete_response += chunk
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
 # test_completion_sagemaker_stream()
diff --git a/litellm/utils.py b/litellm/utils.py
index 8c3863344..9b91ab36d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5427,6 +5427,45 @@ def get_optional_params(
         optional_params["extra_body"] = (
             extra_body  # openai client supports `extra_body` param
         )
+    elif custom_llm_provider == "watsonx":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        if max_tokens is not None:
+            optional_params["max_new_tokens"] = max_tokens
+        if stream:
+            optional_params["stream"] = stream
+        if temperature is not None:
+            optional_params["temperature"] = temperature
+        if top_p is not None:
+            optional_params["top_p"] = top_p
+        if frequency_penalty is not None:
+            optional_params["repetition_penalty"] = frequency_penalty
+        if seed is not None:
+            optional_params["random_seed"] = seed
+        if stop is not None:
+            optional_params["stop_sequences"] = stop
+
+        # WatsonX-only parameters
+        extra_body = {}
+        if "decoding_method" in passed_params:
+            extra_body["decoding_method"] = passed_params.pop("decoding_method")
+        if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
+            extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens", None))
+        if "top_k" in passed_params:
+            extra_body["top_k"] = passed_params.pop("top_k")
+        if "truncate_input_tokens" in passed_params:
+            extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
+        if "length_penalty" in passed_params:
+            extra_body["length_penalty"] = passed_params.pop("length_penalty")
+        if "time_limit" in passed_params:
+            extra_body["time_limit"] = passed_params.pop("time_limit")
+        if "return_options" in passed_params:
+            extra_body["return_options"] = passed_params.pop("return_options")
+        optional_params["extra_body"] = (
+            extra_body  # openai client supports `extra_body` param
+        )
     else:  # assume passing in params for openai/azure openai
         print_verbose(
             f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
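To illustrate the mapping above: OpenAI-style arguments are renamed to their watsonx.ai equivalents, while watsonx-only options are collected into `extra_body`. A sketch, assuming provider-specific keyword arguments are forwarded to `get_optional_params` the same way they are for other providers:

```python
# Sketch only: how completion() arguments land in the watsonx "parameters" payload.
from litellm import completion

response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Write one sentence about the ocean."}],
    max_tokens=50,       # mapped to max_new_tokens
    temperature=0.7,     # passed through as temperature
    seed=42,             # mapped to random_seed
    stop=["\n\n"],       # mapped to stop_sequences
    # watsonx-only parameters, collected into extra_body:
    decoding_method="sample",
    top_k=50,
    truncate_input_tokens=2048,
)
print(response.choices[0].message.content)
```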
@@ -5829,6 +5868,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "frequency_penalty",
             "presence_penalty",
         ]
+    elif custom_llm_provider == "watsonx":
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params()


 def get_formatted_prompt(
@@ -6056,6 +6097,8 @@ def get_llm_provider(
             model in litellm.bedrock_models or model in litellm.bedrock_embedding_models
         ):
             custom_llm_provider = "bedrock"
+        elif model in litellm.watsonx_models:
+            custom_llm_provider = "watsonx"
         # openai embeddings
         elif model in litellm.open_ai_embedding_models:
             custom_llm_provider = "openai"
@@ -9750,6 +9793,37 @@ class CustomStreamWrapper:
                 "is_finished": chunk["is_finished"],
                 "finish_reason": finish_reason,
             }
+
+    def handle_watsonx_stream(self, chunk):
+        try:
+            if isinstance(chunk, dict):
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if 'generated_text' in chunk:
+                    response = chunk.replace('data: ', '').strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {"text": "", "is_finished": False}
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+            results = parsed_response.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
+                is_finished = finish_reason != 'not_finished'
+                return {
+                    "text": text,
+                    "is_finished": is_finished,
+                    "finish_reason": finish_reason,
+                    "prompt_tokens": results[0].get("input_token_count", None),
+                    "completion_tokens": results[0].get("generated_token_count", None),
+                }
+            return {"text": "", "is_finished": False}
+        except Exception as e:
+            raise e

     def model_response_creator(self):
         model_response = ModelResponse(stream=True, model=self.model)
@@ -10006,6 +10080,21 @@ class CustomStreamWrapper:
                     print_verbose(f"completion obj content: {completion_obj['content']}")
                     if response_obj["is_finished"]:
                         self.received_finish_reason = response_obj["finish_reason"]
+                elif self.custom_llm_provider == "watsonx":
+                    response_obj = self.handle_watsonx_stream(chunk)
+                    completion_obj["content"] = response_obj["text"]
+                    print_verbose(f"completion obj content: {completion_obj['content']}")
+                    if response_obj.get("prompt_tokens") is not None:
+                        prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
+                        model_response.usage.prompt_tokens = (prompt_token_count + response_obj["prompt_tokens"])
+                    if response_obj.get("completion_tokens") is not None:
+                        model_response.usage.completion_tokens = response_obj["completion_tokens"]
+                    model_response.usage.total_tokens = (
+                        getattr(model_response.usage, "prompt_tokens", 0)
+                        + getattr(model_response.usage, "completion_tokens", 0)
+                    )
+                    if response_obj["is_finished"]:
+                        self.received_finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "text-completion-openai":
                     response_obj = self.handle_openai_text_completion_chunk(chunk)
                     completion_obj["content"] = response_obj["text"]
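For reference, `handle_watsonx_stream` expects server-sent events whose `data:` payload carries a `results` list. The snippet below parses a hand-written event of that shape; it is illustrative, not a captured watsonx.ai response:

```python
# Sketch only: parsing one watsonx-style streaming event the way
# CustomStreamWrapper.handle_watsonx_stream does.
import json

raw_event = (
    'data: {"model_id": "ibm/granite-13b-chat-v2", "results": '
    '[{"generated_text": "Hello", "generated_token_count": 1, '
    '"input_token_count": 8, "stop_reason": "not_finished"}]}'
)

parsed = json.loads(raw_event.replace("data: ", "").strip())
result = parsed["results"][0]
chunk = {
    "text": result.get("generated_text", ""),
    "is_finished": result.get("stop_reason") != "not_finished",
    "finish_reason": result.get("stop_reason"),
    "prompt_tokens": result.get("input_token_count"),
    "completion_tokens": result.get("generated_token_count"),
}
print(chunk)  # {'text': 'Hello', 'is_finished': False, 'finish_reason': 'not_finished', ...}
```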